# Evaluating a downstream classifier on top of pre-trained self-supervised models

In [1]:
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform
from torch.utils.data import Dataset, DataLoader, random_split
from pl_bolts.models.self_supervised import SimCLR, BYOL
# from livelossplot import PlotLosses
import torchvision.transforms as T
# from sklearn.neighbors import KNeighborsClassifier
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score
import plotly.express as px
import pytorch_lightning as pl
from time import time
from PIL import Image
import torchmetrics
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

  warn(
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)
  warn_missing_pkg("gym")


In [2]:
def get_best_checkpoint(
    selected_model,
    models_root='/home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/',
):
    """Return the path of the newest checkpoint of the highest-numbered run.

    Looks inside ``<models_root>/<selected_model>/lightning_logs`` for
    directories named ``version_<int>``, picks the one with the largest
    number and returns the most recently modified ``*.ckpt`` file found in
    its ``checkpoints`` sub-directory.

    Args:
        selected_model: Name of the model directory below ``models_root``.
        models_root: Directory that holds one sub-directory per trained
            model. Defaults to the original hard-coded location.

    Returns:
        Path of the selected checkpoint file.

    Raises:
        FileNotFoundError: If no ``version_<int>`` directory or no ``.ckpt``
            file can be found.
    """
    logs_dir = os.path.join(models_root, selected_model, 'lightning_logs')

    # Only entries that are exactly "version_<digits>" count as runs; other
    # files/directories (notes, backups, ...) are ignored instead of
    # crashing the integer conversion.
    version_numbers = []
    for entry in os.listdir(logs_dir):
        prefix, separator, suffix = entry.partition('version_')
        if prefix == '' and separator and suffix.isdigit():
            version_numbers.append(int(suffix))

    if not version_numbers:
        raise FileNotFoundError(f'No version_<n> directory found in {logs_dir}')

    best_version = max(version_numbers)
    version_dir = os.path.join(logs_dir, f'version_{best_version}', 'checkpoints')

    # os.listdir order is arbitrary, so select deterministically: newest
    # modification time first, file name as tie-breaker.
    candidates = [
        os.path.join(version_dir, name)
        for name in os.listdir(version_dir)
        if name.endswith('.ckpt')
    ]
    if not candidates:
        raise FileNotFoundError(f'No .ckpt file found in {version_dir}')

    best_checkpoint = max(candidates, key=lambda path: (os.path.getmtime(path), path))
    print('LATEST CHECKPOINT', best_checkpoint)

    return best_checkpoint

In [3]:
class ICDARDataset(Dataset):
    """Image dataset described by a semicolon-separated label file.

    The label file is read with pandas and must contain a ``FILENAME``
    column (image path relative to ``root_dir``) and a ``SCRIPT_TYPE``
    column (returned as the sample label).

    Args:
        csv_filepath: Path of the ``;``-separated label file.
        root_dir: Directory the ``FILENAME`` entries are relative to.
        transforms: Optional callable applied to the opened image.
        convert_rgb: If true, images are converted to RGB before the
            transforms run.
    """

    def __init__(self, csv_filepath, root_dir, transforms=None, convert_rgb=True):
        self.root_dir = root_dir
        self.transforms = transforms
        self.data = pd.read_csv(csv_filepath, sep=';')
        self.convert_rgb = convert_rgb

    def __len__(self):
        """Return the number of rows in the label file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Returns ``None`` (after emitting a warning) when the image file is
        missing or cannot be opened as an image. NOTE(review): a ``None``
        sample presumably needs a ``collate_fn`` that filters it out -
        confirm against the loaders that use this dataset.
        """
        # Local import so the class does not depend on a module-level import
        # that the surrounding file does not have.
        import warnings

        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.root_dir, self.data.loc[idx, 'FILENAME'])

        try:
            image = Image.open(img_path)
        except OSError as ex:
            # Missing, unreadable and unidentifiable images all surface as
            # OSError subclasses; anything else is a real bug and propagates.
            warnings.warn(f'Skipping unreadable image {img_path}: {ex}')
            return None

        if self.convert_rgb:
            image = image.convert('RGB')

        if self.transforms is not None:
            image = self.transforms(image)

        return image, self.data.loc[idx, 'SCRIPT_TYPE']


In [4]:
def load_best_checkpoint(model_name, **model_kwargs):
    """Load the checkpoint selected by ``get_best_checkpoint`` for a model.

    Args:
        model_name: Name of the trained model. ``'SimCLR'`` and
            ``'SimCLRDownstream'`` are handled explicitly; every other name
            goes through the fallback branch.
        **model_kwargs: Extra keyword arguments forwarded to
            ``load_from_checkpoint`` for the two explicitly handled models.

    Returns:
        The ``encoder`` attribute of the loaded SimCLR model for ``'SimCLR'``
        and for the fallback branch, or the full ``DownstreamClassifier`` for
        ``'SimCLRDownstream'``.
    """
    checkpoint = get_best_checkpoint(model_name)
    print(model_kwargs)

    if model_name == 'SimCLRDownstream':
        return DownstreamClassifier.load_from_checkpoint(checkpoint, strict=False, **model_kwargs)

    if model_name == 'SimCLR':
        model = SimCLR.load_from_checkpoint(checkpoint, strict=False, **model_kwargs)
        return model.encoder

    # Fallback: the checkpoint is loaded as a SimCLR model without the extra
    # kwargs. The original returned an undefined name here (NameError); the
    # encoder is returned instead, mirroring the 'SimCLR' branch.
    # NOTE(review): strict=False hides key mismatches when the checkpoint was
    # not produced by SimCLR - confirm this fallback is what is wanted.
    model = SimCLR.load_from_checkpoint(checkpoint, strict=False)
    return model.encoder

In [5]:
def data_factory(dataset_name, root_dir, label_filepath, transforms, mode, batch_size, collate_fn=None, num_cpus=None):
    """Build the data loaders for one mode of a named dataset.

    The dataset is split 70 % / 20 % / remainder into train / validation /
    test subsets with a fixed seed (42), so repeated calls produce the same
    split.

    Args:
        dataset_name: Dataset identifier (case-insensitive); only ``'icdar'``
            is implemented.
        root_dir: Directory containing the images.
        label_filepath: Path of the label file passed to the dataset.
        transforms: Zero-argument callable (e.g. a transform class) that is
            called to create the transform applied to every image.
        mode: ``'train'`` (returns ``'train'`` and ``'val'`` loaders) or
            ``'test'`` (returns a ``'test'`` loader).
        batch_size: Batch size of every returned loader.
        collate_fn: Optional zero-argument callable producing the collate
            function; called once per loader.
        num_cpus: Number of loader workers; defaults to ``os.cpu_count()``.

    Returns:
        Dict mapping split name to ``DataLoader``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
        KeyError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    if dataset_name.lower() == 'icdar':
        dataset = ICDARDataset(label_filepath, root_dir, transforms=transforms(), convert_rgb=True)
    else:
        raise NotImplementedError(f'Dataset {dataset_name} is not implemented')

    # Exact comparison: the previous ``mode in 'train'`` was a substring test
    # and accepted values such as '' or 'rain' as train mode.
    if mode not in ('train', 'test'):
        raise KeyError(f'Unknown mode: {mode}')

    total_count = len(dataset)
    train_count = int(0.7 * total_count)
    val_count = int(0.2 * total_count)
    test_count = total_count - train_count - val_count

    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        (train_count, val_count, test_count),
        generator=torch.Generator().manual_seed(42)
    )

    # os.cpu_count() may return None; persistent workers need at least one.
    num_workers = num_cpus or os.cpu_count() or 1

    def make_loader(subset, training):
        # Only the training loader shuffles and drops the last partial batch.
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=training,
            drop_last=training,
            pin_memory=True,
            persistent_workers=True,
            num_workers=num_workers,
            collate_fn=collate_fn() if collate_fn else None
        )

    if mode == 'train':
        return {
            'train': make_loader(train_dataset, training=True),
            'val': make_loader(val_dataset, training=False)
        }

    return {'test': make_loader(test_dataset, training=False)}


In [6]:
# root_dir = os.path.join(
#     '/home/woody/iwfa/iwfa028h/dev/faps', 'data', 'ICDAR2017_CLaMM_Training'
# )

# train_dataloaders = data_factory(
#     dataset_name='icdar',
#     root_dir=root_dir,
#     label_filepath=os.path.join(root_dir, '@ICDAR2017_CLaMM_Training.csv'),
#     transforms=SimCLREvalDataTransform,
#     mode='train',
#     batch_size=256,
#     num_cpus=4
# )

# test_dataloaders = data_factory(
#     dataset_name='icdar',
#     root_dir=root_dir, 
#     label_filepath=os.path.join(root_dir, '@ICDAR2017_CLaMM_Training.csv'),
#     transforms=SimCLREvalDataTransform,
#     mode='test',
#     batch_size=256,
#     num_cpus=4
# )

In [7]:
def plot_features(model, data_loader, num_feats, batch_size, num_samples, perplexity=25):
    """Embed samples with ``model`` and show a 3-D t-SNE scatter plot.

    The first element of every batch's view triple is passed through the
    model on the GPU; the last element of the model output is collected as
    the feature matrix, reduced to three dimensions with t-SNE and plotted
    with plotly, coloured by label.

    Args:
        model: Module whose output can be indexed with ``[-1]`` to obtain a
            tensor of per-sample features.
        data_loader: Iterable yielding ``((x1, x2, _), label)`` batches.
        num_feats: Width of the feature vectors produced by the model.
        batch_size: Loader batch size; only used to estimate the sample
            count when the loader exposes no ``dataset`` attribute.
        num_samples: Stop after at least this many samples; a falsy value
            means "use the whole loader".
        perplexity: t-SNE perplexity.
    """
    if not num_samples:
        # len(data_loader) counts batches, not samples, which made the
        # loop below stop after a single batch; use the sample count.
        dataset = getattr(data_loader, 'dataset', None)
        num_samples = len(dataset) if dataset is not None else len(data_loader) * batch_size

    # Seed both lists with empty float arrays so the concatenated results
    # keep the dtype/shape the previous np.append-based code produced.
    feat_chunks = [np.empty((0, num_feats))]
    label_chunks = [np.array([])]
    model.eval()
    model.cuda()

    processed_samples = 0
    with torch.no_grad():
        for (x1, _x2, _), label in data_loader:
            if processed_samples >= num_samples:
                break
            # NOTE(review): squeeze() also removes the batch axis when a
            # batch holds a single sample - confirm this is intended.
            x1 = x1.squeeze().cuda()
            out = model(x1)
            out = out[-1].detach().cpu().numpy()
            print(out.shape)
            feat_chunks.append(out)
            label_chunks.append(np.asarray(label))
            # Count what was actually embedded (the last batch may be short).
            processed_samples += out.shape[0]

    feats = np.concatenate(feat_chunks, axis=0)
    labels = np.concatenate(label_chunks, axis=0)

    tsne = TSNE(n_components=3, perplexity=perplexity, init='pca')
    x_feats = tsne.fit_transform(feats)

    dim_red_df = pd.DataFrame(x_feats)
    dim_red_df['labels'] = pd.Categorical(labels)
    fig = px.scatter_3d(dim_red_df, x=0, y=1, z=2, color='labels', size_max=5)
    fig.show()


In [8]:
def generate_from_embeddings(model, dataloader):
    """Run ``model`` over ``dataloader`` and collect embeddings and labels.

    Every batch is expected to be ``((x1, x2, _), labels)``. Only ``x1`` is
    used: it is moved to the CUDA device and passed through the model with
    gradients disabled, and the last element of the model output is kept.

    Returns:
        Tuple ``(X, y)`` of numpy arrays: the concatenated embeddings and
        the concatenated labels, in loader order.
    """
    embedding_chunks, label_chunks = [], []

    for batch_images, batch_labels in dataloader:
        first_view, _second_view, _unused = batch_images
        first_view = first_view.to('cuda')

        with torch.no_grad():
            last_output = model(first_view)[-1]
            embedding_chunks.append(last_output.detach().cpu().numpy())

        label_chunks.append(batch_labels.numpy())

    return np.concatenate(embedding_chunks), np.concatenate(label_chunks)

In [9]:
def cluster_embeddings():
    """Fit a 10-nearest-neighbour classifier on SimCLR embeddings.

    Loads the SimCLR encoder, embeds the training loader, fits the
    classifier, then embeds the test loader and prints the accuracy.

    Relies on the module-level names ``train_dataloaders`` and
    ``test_dataloaders`` (dicts with a ``'train'`` / ``'test'`` loader);
    they must be defined before this function is called.

    Returns:
        The fitted ``KNeighborsClassifier``.
    """
    # The module-level import of KNeighborsClassifier is commented out in
    # this file, so import it here to avoid a NameError.
    from sklearn.neighbors import KNeighborsClassifier

    simclr_encoder = load_best_checkpoint('SimCLR')
    simclr_encoder.eval()
    simclr_encoder.cuda()

    X_train, y_train = generate_from_embeddings(simclr_encoder, train_dataloaders.get('train'))

    knn = KNeighborsClassifier(n_neighbors=10)
    print(knn)
    knn.fit(X_train, y_train)

    X_test, y_test = generate_from_embeddings(simclr_encoder, test_dataloaders.get('test'))
    y_pred = knn.predict(X_test)

    print(accuracy_score(y_test, y_pred))

    return knn
    
#     plot_features(simclr_encoder, dataloaders.get('train'), 2048, 256, 1000, 1000)

In [None]:
class DownstreamClassifier(pl.LightningModule):
    """Linear classifier trained on the features of a pre-trained base model.

    The base model is obtained through ``load_best_checkpoint`` and is always
    evaluated under ``torch.no_grad()``, so no gradients reach it; only the
    linear head is updated by the optimizer.

    Args:
        base_model_name: Model name handed to ``load_best_checkpoint``.
        features: Width of the feature vector produced by the base model.
        num_classes: Number of target classes.
        learning_rate: Learning rate of the Adam optimizer.
    """

    def __init__(self, base_model_name='SimCLR', features=2048, num_classes=13, learning_rate=1e-2):
        print(base_model_name, features, num_classes)
        super().__init__()

        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.num_classes = num_classes

        self.base_model = load_best_checkpoint(base_model_name, num_classes=num_classes)

        # Kept as nn.Sequential so the parameter names stored in existing
        # checkpoints ("classifier.0.*") stay valid.
        self.classifier = nn.Sequential(
            nn.Linear(features, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        with torch.no_grad():
            x = self.base_model(x)
        # The rest of this file indexes the encoder output with [-1], i.e. it
        # is a sequence of tensors; a linear layer needs the tensor itself.
        if isinstance(x, (list, tuple)):
            x = x[-1]
        return self.classifier(x)

    def _logits_and_labels(self, batch):
        """Unpack a ``((x1, x2, _), label)`` batch and classify ``x1``."""
        (x1, _x2, _), label = batch
        return self(x1), label

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy training loss."""
        y_hat, label = self._logits_and_labels(batch)
        loss = F.cross_entropy(y_hat, label)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the cross-entropy validation loss."""
        y_hat, label = self._logits_and_labels(batch)
        loss = F.cross_entropy(y_hat, label)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the multiclass accuracy of the batch, aggregated per epoch."""
        y_hat, label = self._logits_and_labels(batch)
        # A fresh metric per batch: the logged value is the batch accuracy,
        # which the logger aggregates at epoch end (on_epoch=True).
        acc_metric = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device)
        acc = acc_metric(y_hat, label)
        self.log('test_acc', acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Return the Adam optimizer over all module parameters."""
        # All parameters are registered (unchanged from before); the base
        # model never receives gradients because forward() runs it under
        # torch.no_grad().
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


In [11]:
def train_downstream_model(model_name, max_epochs=10):
    """Train a ``DownstreamClassifier`` on the ICDAR train/val loaders.

    Args:
        model_name: Base model name passed to ``DownstreamClassifier``.
        max_epochs: Maximum number of training epochs.

    Returns:
        The classifier instance that was passed to ``trainer.fit``.
    """
    dataset_root = os.path.join(
        '/home/woody/iwfa/iwfa028h/dev/faps', 'data', 'ICDAR2017_CLaMM_Training'
    )
    labels_csv = os.path.join(dataset_root, '@ICDAR2017_CLaMM_Training.csv')

    loaders = data_factory(
        dataset_name='icdar',
        root_dir=dataset_root,
        label_filepath=labels_csv,
        transforms=SimCLREvalDataTransform,
        mode='train',
        batch_size=64,
        num_cpus=4
    )

    # Logs and checkpoints go next to the dataset directory, under
    # trained_models/SimCLRDownstream.
    output_dir = os.path.abspath(
        os.path.join(dataset_root, '..', 'trained_models', 'SimCLRDownstream')
    )
    # Keep the checkpoint with the lowest logged "val_loss".
    checkpoint_callback = pl.callbacks.ModelCheckpoint(mode="min", monitor="val_loss")
    progress_callback = pl.callbacks.RichProgressBar()

    trainer = pl.Trainer(
        default_root_dir=output_dir,
        accelerator='gpu',
        devices=-1,
        max_epochs=max_epochs,
        enable_progress_bar=True,
        precision=16,
        callbacks=[checkpoint_callback, progress_callback]
    )

    classifier = DownstreamClassifier(model_name, 2048, 13)
    trainer.fit(classifier, loaders.get('train'), loaders.get('val'))

    return classifier

In [14]:
# Train the downstream classifier on the 'SimCLR' base model for at most 100 epochs.
downstream_model = train_downstream_model('SimCLR', 100)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


SimCLR 2048 13
LATEST CHECKPOINT /home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/SimCLR/lightning_logs/version_584360/checkpoints/epoch=474-step=4275.ckpt
{'num_classes': 13}


  obj = cls(**_cls_kwargs)
  return backbone(first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False)
  return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
  model = ResNet(block, layers, **kwargs)
  conv1x1(self.inplanes, planes * block.expansion, stride),
  block(
  self.conv2 = conv3x3(width, width, stride, groups, dilation)
  self.projection = Projection(input_dim=self.hidden_mlp, hidden_dim=self.hidden_mlp, output_dim=self.feat_dim)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [14]:
def test_downstream_model(model_name):
    """Evaluate the latest checkpoint of ``model_name`` on the ICDAR test loader.

    Args:
        model_name: Name of the trained model passed to
            ``load_best_checkpoint``.
    """
    dataset_root = os.path.join(
        '/home/woody/iwfa/iwfa028h/dev/faps', 'data', 'ICDAR2017_CLaMM_Training'
    )
    labels_csv = os.path.join(dataset_root, '@ICDAR2017_CLaMM_Training.csv')

    loaders = data_factory(
        dataset_name='icdar',
        root_dir=dataset_root,
        label_filepath=labels_csv,
        transforms=SimCLREvalDataTransform,
        mode='test',
        batch_size=64,
        num_cpus=4
    )

    progress_callback = pl.callbacks.RichProgressBar()
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=-1,
        max_epochs=1,
        enable_progress_bar=True,
        precision=16,
        enable_checkpointing=False,
        callbacks=[progress_callback]
    )

    # NOTE(review): base_model_name is fixed to 'MAE' here regardless of how
    # the checkpoint was trained - confirm this is intended.
    classifier = load_best_checkpoint(model_name, base_model_name='MAE', features=2048, num_classes=13)

    trainer.test(classifier, loaders.get('test'))

In [None]:
# Evaluate the latest 'SimCLRDownstream' checkpoint on the test loader.
test_downstream_model('SimCLRDownstream')

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


LATEST CHECKPOINT /home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/SimCLRDownstream/lightning_logs/version_595294/checkpoints/epoch=5-step=228.ckpt
{'base_model_name': 'MAE', 'features': 2048, 'num_classes': 13}
MAE 2048 13
LATEST CHECKPOINT /home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/MAE/lightning_logs/version_594369/checkpoints/epoch=88-step=1691.ckpt
{'num_classes': 13}
