In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torchvision import models

# Method 1: SimCLR

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch


class CustomScheduler: # Need this else NaNs
    def __init__(self, optimizer, warmup_epochs, initial_lr, final_lr, total_epochs):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.total_epochs = total_epochs
        self.after_warmup_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)

    def step(self, epoch):
        if epoch < self.warmup_epochs:
            # Warm-up phase
            lr = self.initial_lr + (self.final_lr - self.initial_lr) * epoch / self.warmup_epochs
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        else:
            # After warm-up, use the cosine annealing schedule
            self.after_warmup_scheduler.step(epoch - self.warmup_epochs)


class simCLR(nn.Module): # the similarity loss of simCLR

    def __init__(self, encoder, device,batch_size,epochs):
        super().__init__()
        self.model = encoder.to(device) # define the encoder here
        self.criterion = torch.nn.CrossEntropyLoss().to(device)
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = device
        self.optimizer          = torch.optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.95), weight_decay=0.05)
        #self.scheduler          = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=20, eta_min=1e-8)#, last_epoch=-1)
        self.scheduler          = CustomScheduler(self.optimizer, warmup_epochs=50, initial_lr=1e-5, final_lr=1e-3, total_epochs=150)



    def SimCLR_loss(self, features):
        n_views = 2
        # Define the similarity matrix's pattern (which elements are related in the batch and which aren't)
        #original_tensor = torch.arange(0,self.batch_size,1)
        # Create the repeated pattern
        #pattern = torch.repeat_interleave(original_tensor, repeats=n_views)
        pattern = torch.cat([torch.arange(self.batch_size) for i in range(n_views)], dim=0)
        # make similarity matrix by the above method (need to understand this)
        pattern = (pattern.unsqueeze(0) == pattern.unsqueeze(1)).float()
        pattern = pattern.to(self.device)
        mask = torch.eye(pattern.shape[0])
        mask = mask.to(self.device)
        pattern = pattern-mask

        features = F.normalize(features,dim=1)
        similarity_matrix = torch.matmul(features, features.T)
        similarity_matrix = similarity_matrix - mask

        # select and combine positives
        positives = similarity_matrix[pattern.bool()].view(pattern.shape[0],-1)
        negatives = similarity_matrix[~pattern.bool()].view(similarity_matrix.shape[0],-1)

        logits = torch.cat([positives,negatives],dim=1)
        # we have to further develop the logits function as follows: we will define a temperature argument that sets the shape of the distribution
        temperature = 0.07
        logits = logits / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)

        return logits, labels

    def get_encoder(self):
        return self.model

    def save_checkpoint(self, file_path):
        """
        Save the model checkpoint.
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }
        torch.save(checkpoint, file_path)
        print(f"Checkpoint saved to {file_path}")

    def load_checkpoint(self, file_path, device):
        """
        Load the model from the checkpoint.
        """
        checkpoint = torch.load(file_path, map_location=device)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Checkpoint loaded from {file_path}")



    def train(self, dataloader):
        self.losses = []  # Track losses

        # Start training
        for epoch in range(self.epochs):
            # Initialize tqdm progress bar
            train_loader = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{self.epochs}")

            for views, _ in train_loader:  # Unpack data and labels from each batch
                imgs = torch.cat(views, dim=0).to(self.device)

                # Load images and calculate InfoNCE loss
                features = self.model(imgs)
                logits, labels = self.SimCLR_loss(features)
                loss = self.criterion(logits, labels)

                # Append the loss
                self.losses.append(loss.item())

                # Perform optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update tqdm progress bar with the current loss
                train_loader.set_postfix(loss=loss.item())

            if (int(epoch)%10 == 0):
              file_path = '/content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint.pth'

              # Save the current state of the model and optimizer
              self.save_checkpoint(file_path)


            #self.scheduler.step(epoch)


        return self.losses

In [4]:
from torchvision import transforms


class ContrastiveTransformations(object):
    def __init__(self, size=32, nviews=2, **kwargs):
        self.size = size
        self.nviews = nviews
        self.horizontal_flip = kwargs.get('horizontal_flip', True)
        self.resized_crop = kwargs.get('resized_crop', True)
        self.color_jitter = kwargs.get('color_jitter', True)
        self.random_grayscale = kwargs.get('random_grayscale', True)
        self.to_tensor = kwargs.get('to_tensor', True)
        self.normalize = kwargs.get('normalize', True)
        self.brightness = kwargs.get('brightness', 0.5)
        self.contrast = kwargs.get('contrast', 0.5)
        self.saturation = kwargs.get('saturation', 0.5)
        self.hue = kwargs.get('hue', 0.1)
        self.color_jitter_p = kwargs.get('color_jitter_p', 0.8)
        self.grayscale_p = kwargs.get('grayscale_p', 0.2)
        self.mean = kwargs.get('mean', (0.5,))
        self.std = kwargs.get('std', (0.5,))

    def __call__(self, x):
        transforms_list = []
        if self.horizontal_flip:
            transforms_list.append(transforms.RandomHorizontalFlip())
        if self.resized_crop:
            transforms_list.append(transforms.RandomResizedCrop(self.size))
        if self.color_jitter:
            color_jitter_transform = transforms.ColorJitter(
                brightness=self.brightness,
                contrast=self.contrast,
                saturation=self.saturation,
                hue=self.hue
            )
            transforms_list.append(transforms.RandomApply([color_jitter_transform], p=self.color_jitter_p))
        if self.random_grayscale:
            transforms_list.append(transforms.RandomGrayscale(p=self.grayscale_p))
        if self.to_tensor:
            transforms_list.append(transforms.ToTensor())
        if self.normalize:
            transforms_list.append(transforms.Normalize(self.mean, self.std))

        composed_transforms = transforms.Compose(transforms_list)
        return [composed_transforms(x) for _ in range(self.nviews)]




In [7]:
# Encoder 1: ResNet

class ConvBlock(nn.Module):
    def __init__(self,InputChannel,OutputChannel,Kernel,Padding,Stride):
        super().__init__()
        self.input_channel      = InputChannel
        self.kernel             = Kernel
        self.padding            = Padding
        self.stride             = Stride
        self.output_channel     = OutputChannel
        self.convblock          = nn.Sequential(
                nn.Conv2d(
                        in_channels     =self.input_channel,
                        out_channels    =self.output_channel,
                        kernel_size     =self.kernel,
                        stride          =self.stride,
                        padding         =self.padding,
                        bias            =False),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=self.kernel,stride=self.padding,padding=self.padding)
        )

    def forward(self,x):
        x = self.convblock(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride = 1, downsample = None):
        super().__init__() # inherit properties of nn.Module
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channels,out_channels=out_channels, kernel_size=3,stride=stride,padding=1),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU()
                                   )
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=out_channels,out_channels=out_channels,stride=1,padding=1,kernel_size=3),
                                   nn.BatchNorm2d(out_channels)
                                   )
        self.out_channels = out_channels
        self.nonlinear = nn.ReLU()
        self.downsample = downsample

    def forward(self,x):
        residual = x #save residual in separate variable
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual # add the residual connections
        out = self.nonlinear(out)
        return out


class ResNet(nn.Module):
    def __init__(self,block,blocks,outputchannels):
        super().__init__()
        # add initial convolutional layer
        self.convlayer  = ConvBlock(InputChannel=3,OutputChannel=64,Kernel=3,Padding=1,Stride=1)
        # add the residual blocks
        self.layer1     = self._make_layer(block,inchannels=64,channels=128,numblocks = blocks[0],stride=2)
        self.layer2     = self._make_layer(block,inchannels=128,channels=256,numblocks = blocks[1],stride=2)
        self.layer3     = self._make_layer(block,inchannels=256,channels=512,numblocks = blocks[2],stride=2)
        #self.layer4     = self._make_layer(block,inchannels=64,channels=128,layers[0],stride=1)
        # add the average pooling block
        self.avgpooling = nn.AdaptiveAvgPool2d((1,1)) # compresses the above to 512,1,1 output size by averaging over the other dimensions
        self.fc         = nn.Linear(512,outputchannels)

    def _make_layer(self,block,inchannels,channels,numblocks,stride=1):
        # first define whether a downsample is needed:
        downsample = None
        if stride != 1 or inchannels != channels:
            downsample  = nn.Sequential(
                nn.Conv2d(in_channels=inchannels,out_channels = channels, kernel_size=1,stride = stride),
                nn.BatchNorm2d(channels)
            )
        layers = []
        layers.append(block(inchannels,channels,stride,downsample))
        for _ in range(1,numblocks): # loop over number of blocks
            layers.append(block(channels,channels))

        return nn.Sequential(*layers) # * operator is used to unpack the elements of an iterable (layers)

    def forward(self,x):
        x = self.convlayer(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpooling(x)
        x = torch.flatten(x,1) # convert to 1X1 vector
        x = self.fc(x)
        return x



In [5]:


# Encoder 2: ViT

class PatchEmbed(nn.Module):
    # converts image into patch embeddings based on total number of non-overlapping crops.
    # For each image containing n patches, there should be n embedding vectors per image, so a n x embedding_vector matrix.
    def __init__(self,img_size,patch_size,in_channels=3, embed_dim=256):
        super().__init__()
        self.img_size       = img_size
        self.patch_size     = patch_size
        self.in_channels    = in_channels
        self.n_patches      = (img_size // patch_size)**2
        self.project        = nn.Conv2d(
                                    in_channels     =in_channels,
                                    out_channels    = embed_dim,
                                    kernel_size     = patch_size,
                                    stride          = patch_size,
                                    )

    def forward(self,x):
        # x has input a tensor of shape B, C, H, W (batch, channel, height, width)

        x = self.project(x)     # Batch X Embedding Dim X sqrt(N_patches) X sqrt(N_patches)
        x = x.flatten(2)        # Batch X Embedding Dim X N_patches
        x = x.transpose(1,2)    # Batch X N_patches X Embedding Dim

        return x


class Attention(nn.Module):

    def __init__(self, embed_dim, n_heads, qkv_bias = False, attn_dropout = 0., projection_dropout=0.):
        super().__init__()
        self.embed_dim          = embed_dim
        self.n_heads            = n_heads
        self.head_dim           = embed_dim // n_heads
        self.scale              = self.head_dim ** -0.5 # From vaswani paper
        self.qkv                = nn.Linear(embed_dim, 3* embed_dim) # convert input to query, key and value
        self.project            = nn.Linear(embed_dim,embed_dim)
        self.project_dropout    = nn.Dropout(projection_dropout)
        self.attention_dropout  = nn.Dropout(attn_dropout)

    def forward(self,x):

        batches, tokens, embed_dim = x.shape # tokens = total patches plus 1 class token

        QueryKeyValue = self.qkv(x) # it is like a neural form of repmat function.
        QueryKeyValue = QueryKeyValue.reshape(batches, tokens, 3, self.n_heads,self.head_dim)
        # Above has following dim: batches, tokens, [Query  Key Value], num_heads, head_dim
        QueryKeyValue = QueryKeyValue.permute(      2,      0, 3,             1,           4)
        # Above has following dim: QKV, batches, num_heads, tokens, head_dim
        Query, Key, Value    = QueryKeyValue[0], QueryKeyValue[1], QueryKeyValue[2]
        # Above has following dim: batches, num_heads, tokens, head_dim
        Attn_dot_product     = (Query @ Key.transpose(-2, -1)) * self.scale
        # Above has following dim: batches, num_heads, tokens, tokens
        Attention_mechanism  = Attn_dot_product.softmax(dim=-1)
        # Above has following dim: batches, num_heads, tokens, tokens
        Attention_mechanism  = self.attention_dropout(Attention_mechanism)
        # Applying the mask (from Values)
        Masking_mechanism    = (Attention_mechanism @ Value).transpose(1,2)
        # Above has following dim: batches, tokens, num_heads, head_dimension
        Masking_mechanism    = Masking_mechanism.flatten(2)
        # Above has following dim: batches, tokens, (num_heads*head_dimension), or, batches, tokens, embedding_dim
        Projection_operation = self.project(Masking_mechanism)
        Projection_operation = self.project_dropout(Projection_operation)

        return Projection_operation


class MultiLayerPerceptron(nn.Module):

    def __init__(self,in_features,hidden_features,out_features,dropout=0.):
        super().__init__()
        self.fc1            = nn.Linear(in_features,hidden_features)
        self.fc2            = nn.Linear(hidden_features,out_features)
        self.dropout        = nn.Dropout(dropout)
        self.activation     = nn.GELU()

    def forward(self,x): # x :: batches, tokens, in features
        x = self.fc1(x) # x :: batches, tokens, hidden features
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x) # x :: batches, tokens, out features
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):

    def __init__(self, embedding_dim, num_heads, MLP_ratio=4.0, qkv_bias = True, attention_dropout=0., projection_dropout=0.):
        super().__init__()
        self.norm1      = nn.LayerNorm(embedding_dim,eps=1e-6)
        self.norm2      = nn.LayerNorm(embedding_dim,eps=1e-6)
        self.attention  = Attention(embedding_dim,num_heads,qkv_bias,attention_dropout,projection_dropout)
        hidden_features = int(MLP_ratio * embedding_dim)
        self.mlp        = MultiLayerPerceptron(embedding_dim, hidden_features, embedding_dim, projection_dropout)

    def forward(self,x):
        x = x + self.attention(x)
        x = x + self.mlp(x)
        return x

class ViT_encoder(nn.Module):

    def __init__(self,
                 image_size,            # image_size (int)            : size of the input image
                 patch_size,            # patch_size (int)            : size of the patches to be extracted from the input image
                 in_channels,           # in_channels (int)           : number of input channels
                 embedding_dim,         # embedding_dim (int)         : number of elements of the embedding vector (per patch)
                 feature_size,          # feature_size (int)          : Total size of feature vector
                 n_blocks,              # n_blocks (int)              : total number of sequential transformer blocks (a.k.a. depth)
                 n_heads,               # n_heads (int)               : total number of attention heads per transformer block
                 mlp_ratio,             # mlp_ratio (float)           : the ratio by which embedding dimension expands inside a transformer block (in the MLP layer after attention)
                 qkv_bias,              # qkv_bias (bool)             : whether to add a bias term to the qkv projection layer or not
                 attention_dropout,     # attention_dropout (float)   : dropout in the attention layer
                 projection_dropout     # projection_dropout (float)  : dropout in the projection layer
                 ):
        super().__init__()
        self.patch_embedding    = PatchEmbed(
                                            img_size        =   image_size,
                                            patch_size      =   patch_size,
                                            in_channels     =   in_channels,
                                            embed_dim       =   embedding_dim
                                            )

        self.class_token        = nn.Parameter(torch.zeros(1,1,embedding_dim))
        self.position_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches + 1, embedding_dim))
        self.position_dropout   = nn.Dropout(p = projection_dropout)
        self.blocks             = nn.ModuleList(
                                        [
                                            TransformerBlock(
                                                            embedding_dim       = embedding_dim,
                                                            num_heads           = n_heads,
                                                            MLP_ratio           = mlp_ratio,
                                                            qkv_bias            = qkv_bias,
                                                            attention_dropout   = attention_dropout,
                                                            projection_dropout  = projection_dropout,
                                                             )
                                         for _ in range(n_blocks)]
                                        )
        self.norm               = nn.LayerNorm(embedding_dim, eps=1e-6)
        self.head               = nn.Linear(embedding_dim, feature_size)

    def forward(self,x):

        batches             = x.shape[0] # total samples per batch
        x                   = self.patch_embedding(x) # convert images to patch embedding
        class_token         = self.class_token.expand(batches, -1, -1) #
        x                   = torch.cat((class_token,x), dim=1) # class token is not appended to the patch tokens
        x                   = x + self.position_embedding#(x) # Add the position embedding mechanism
        x                   = self.position_dropout(x)
        for block in self.blocks:
            x = block(x)
        x                   = self.norm(x) # add the layer norm mechanism now, giving us n_samples X (class token + patch token) X embedding dim
        x                   = x[:, 1:, :].mean(dim=1)  # global pool without cls token, giving us n_samples X embedding_dim
        # the 1: is done in the second dim because the first entry there is the class token, which we do not need (why do we have it then? lol...)
        x                   = self.head(x) # expand feature set to intended feature size
        return x


In [6]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torch

BATCH_SIZE = 2500
EPOCHS = 150


contrast_transforms = ContrastiveTransformations(
        nviews          =   2,
        horizontal_flip =   True,
        resized_crop    =   True,
        color_jitter    =   True,
        random_grayscale=   True,
        brightness      =   0.5,
        contrast        =   0.5,
        saturation      =   0.5,
        hue             =   0.1,
        color_jitter_p  =   0.8,
        grayscale_p     =   0.2,
        to_tensor       =   True,
        normalize       =   True,
        mean            =   (0.5,),
        std             =   (0.5,)
    )

# Initialize dataset and dataloader
cifar_trainset  = CIFAR10(root='./data',train=True,download=True, transform=contrast_transforms)
train_loader    = DataLoader(cifar_trainset, batch_size=BATCH_SIZE, shuffle=True)


device   = torch.device("cuda")

def resnet34():
    layers   = [3, 5, 7, 5]
    model    = ResNet(ResidualBlock, layers,1000)
    return model

def ViTencoder():
        model = ViT_encoder(
                    image_size          =   32,
                    patch_size          =   16,
                    in_channels         =   3,
                    embedding_dim       =   512,
                    feature_size        =   1000,
                    n_blocks            =   12,
                    n_heads             =   8,
                    mlp_ratio           =   4.0,
                    qkv_bias            =   True,
                    attention_dropout   =   0.2,
                    projection_dropout  =   0.2)
        return model

encoder  = ViTencoder() #resnet34()

simclr_model    = simCLR(
                            encoder         = ViTencoder(),#resnet34(),
                            device          = device,
                            batch_size      = BATCH_SIZE,
                            epochs          = EPOCHS)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29488513.10it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [14]:

simclr_loss_iter  =   simclr_model.train(train_loader)



Epoch 1/50: 100%|██████████| 20/20 [02:25<00:00,  7.25s/it, loss=6.58]


Checkpoint saved to /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint.pth


Epoch 2/50: 100%|██████████| 20/20 [02:24<00:00,  7.24s/it, loss=6.63]
Epoch 3/50: 100%|██████████| 20/20 [02:23<00:00,  7.20s/it, loss=6.62]
Epoch 4/50: 100%|██████████| 20/20 [02:24<00:00,  7.21s/it, loss=6.61]
Epoch 5/50: 100%|██████████| 20/20 [02:23<00:00,  7.20s/it, loss=6.56]
Epoch 6/50: 100%|██████████| 20/20 [02:24<00:00,  7.24s/it, loss=6.64]
Epoch 7/50: 100%|██████████| 20/20 [02:23<00:00,  7.16s/it, loss=6.54]
Epoch 8/50: 100%|██████████| 20/20 [02:23<00:00,  7.17s/it, loss=6.54]
Epoch 9/50: 100%|██████████| 20/20 [02:23<00:00,  7.15s/it, loss=6.58]
Epoch 10/50: 100%|██████████| 20/20 [02:22<00:00,  7.15s/it, loss=6.58]
Epoch 11/50: 100%|██████████| 20/20 [02:23<00:00,  7.17s/it, loss=6.56]


Checkpoint saved to /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint.pth


Epoch 12/50: 100%|██████████| 20/20 [02:24<00:00,  7.23s/it, loss=6.54]
Epoch 13/50: 100%|██████████| 20/20 [02:23<00:00,  7.19s/it, loss=6.58]
Epoch 14/50: 100%|██████████| 20/20 [02:23<00:00,  7.17s/it, loss=6.49]
Epoch 15/50: 100%|██████████| 20/20 [02:23<00:00,  7.17s/it, loss=6.51]
Epoch 16/50: 100%|██████████| 20/20 [02:23<00:00,  7.16s/it, loss=6.58]
Epoch 17/50: 100%|██████████| 20/20 [02:25<00:00,  7.27s/it, loss=6.53]
Epoch 18/50: 100%|██████████| 20/20 [02:26<00:00,  7.33s/it, loss=6.49]
Epoch 19/50: 100%|██████████| 20/20 [02:26<00:00,  7.30s/it, loss=6.6]
Epoch 20/50: 100%|██████████| 20/20 [02:26<00:00,  7.32s/it, loss=6.58]
Epoch 21/50: 100%|██████████| 20/20 [02:27<00:00,  7.35s/it, loss=6.47]


Checkpoint saved to /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint.pth


Epoch 22/50: 100%|██████████| 20/20 [02:26<00:00,  7.31s/it, loss=6.52]
Epoch 23/50: 100%|██████████| 20/20 [02:25<00:00,  7.29s/it, loss=6.4]
Epoch 24/50: 100%|██████████| 20/20 [02:27<00:00,  7.40s/it, loss=6.51]
Epoch 25/50: 100%|██████████| 20/20 [02:30<00:00,  7.53s/it, loss=6.53]
Epoch 26/50: 100%|██████████| 20/20 [02:29<00:00,  7.48s/it, loss=6.47]
Epoch 27/50: 100%|██████████| 20/20 [02:30<00:00,  7.52s/it, loss=6.46]
Epoch 28/50: 100%|██████████| 20/20 [02:27<00:00,  7.40s/it, loss=6.5]
Epoch 29/50: 100%|██████████| 20/20 [02:28<00:00,  7.45s/it, loss=6.45]
Epoch 30/50: 100%|██████████| 20/20 [02:27<00:00,  7.36s/it, loss=6.54]
Epoch 31/50: 100%|██████████| 20/20 [02:26<00:00,  7.32s/it, loss=6.44]


Checkpoint saved to /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint.pth


Epoch 32/50: 100%|██████████| 20/20 [02:27<00:00,  7.37s/it, loss=6.4]
Epoch 33/50: 100%|██████████| 20/20 [02:25<00:00,  7.28s/it, loss=6.44]
Epoch 34/50: 100%|██████████| 20/20 [02:25<00:00,  7.29s/it, loss=6.52]
Epoch 35/50: 100%|██████████| 20/20 [02:24<00:00,  7.22s/it, loss=6.41]
Epoch 36/50: 100%|██████████| 20/20 [02:24<00:00,  7.22s/it, loss=6.44]
Epoch 37/50: 100%|██████████| 20/20 [02:28<00:00,  7.41s/it, loss=6.31]
Epoch 38/50: 100%|██████████| 20/20 [02:26<00:00,  7.34s/it, loss=6.44]
Epoch 39/50: 100%|██████████| 20/20 [02:28<00:00,  7.44s/it, loss=6.44]
Epoch 40/50: 100%|██████████| 20/20 [02:28<00:00,  7.41s/it, loss=6.5]
Epoch 41/50: 100%|██████████| 20/20 [02:29<00:00,  7.46s/it, loss=6.39]


Checkpoint saved to /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint.pth


Epoch 42/50: 100%|██████████| 20/20 [02:27<00:00,  7.40s/it, loss=6.34]
Epoch 43/50: 100%|██████████| 20/20 [02:25<00:00,  7.29s/it, loss=6.37]
Epoch 44/50: 100%|██████████| 20/20 [02:23<00:00,  7.20s/it, loss=6.44]
Epoch 45/50: 100%|██████████| 20/20 [02:24<00:00,  7.20s/it, loss=6.46]
Epoch 46/50: 100%|██████████| 20/20 [02:28<00:00,  7.42s/it, loss=6.37]
Epoch 47/50: 100%|██████████| 20/20 [02:26<00:00,  7.30s/it, loss=6.4]
Epoch 48/50: 100%|██████████| 20/20 [02:25<00:00,  7.29s/it, loss=6.35]
Epoch 49/50: 100%|██████████| 20/20 [02:28<00:00,  7.42s/it, loss=6.37]
Epoch 50/50: 100%|██████████| 20/20 [02:28<00:00,  7.43s/it, loss=6.37]


In [15]:
# Do so model saving and checkpoint saving juuuuuust in case colab fucks up

file_path = '/content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint3.pth'

# Save the current state of the model and optimizer
simclr_model.save_checkpoint(file_path)

Checkpoint saved to /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint3.pth


In [8]:


file_path = '/content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint3.pth'


simclr_model.load_checkpoint(file_path, device)






Checkpoint loaded from /content/drive/MyDrive/SimCLR_UMAP/simclr_checkpoint3.pth


In [9]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances
from scipy import optimize
from sklearn.manifold import SpectralEmbedding

class UMAPRepresentation:
    def __init__(self, features, n_neighbors=15, min_dist=0.5, n_components=2, learning_rate=1, max_iter=200):
        self.features = features
        self.n_neighbors = n_neighbors
        self.min_dist = min_dist
        self.n_components = n_components
        self.learning_rate = learning_rate
        self.max_iter = max_iter
        self.calculate_distances()
        self.initialize_spectral_embedding()
        self.calculate_probabilities()

    def calculate_distances(self):
        self.distances = np.square(euclidean_distances(self.features, self.features))
        self.rho = [sorted(self.distances[i])[1] for i in range(self.distances.shape[0])]

    def initialize_spectral_embedding(self):
        model = SpectralEmbedding(n_components=self.n_components, n_neighbors=50)
        self.specembed = model.fit_transform(self.features)


    def plot_initial_embedding(self):
        plt.figure(figsize=(5, 5))
        plt.scatter(self.specembed[:, 0], self.specembed[:, 1], c=self.Y_train.astype(int), cmap='tab10', s=50)
        plt.title('Laplacian Eigenmap', fontsize=14)
        plt.xlabel("LAP1", fontsize=12)
        plt.ylabel("LAP2", fontsize=12)
        plt.show()

    def prob_high_dim_umap(self, sigma, dist_row):
        d = self.distances[dist_row] - self.rho[dist_row]
        d[d < 0] = 0
        return np.exp(-d / sigma)

    def func(self, x, min_dist):
        y = []
        for i in range(len(x)):
            if x[i] <= min_dist:
                y.append(1)
            else:
                y.append(np.exp(-x[i] + min_dist))
        return y

    def k(self, prob):
        return np.power(2, np.sum(prob))

    def sigma_binary_search(self, k_of_sigma, fixed_k):
        sigma_low, sigma_high = 0, 1000
        for i in range(20):
            sigma_mid = (sigma_low + sigma_high) / 2
            if k_of_sigma(sigma_mid) < fixed_k:
                sigma_low = sigma_mid
            else:
                sigma_high = sigma_mid
            if np.abs(fixed_k - k_of_sigma(sigma_mid)) <= 1e-5:
                break
        return sigma_mid

    def calculate_probabilities(self):
        n = self.features.shape[0]
        self.prob = np.zeros((n, n))
        self.sigma_array = []
        for dist_row in range(n):
            func = lambda sigma: self.k(self.prob_high_dim_umap(sigma, dist_row))
            bin_search_result = self.sigma_binary_search(func, self.n_neighbors)
            self.prob[dist_row] = self.prob_high_dim_umap(bin_search_result, dist_row)
            self.sigma_array.append(bin_search_result)
        self.P = (self.prob + np.transpose(self.prob)) / 2

    def dist_low_dim(self, x, a, b):
        return 1 / (1 + a * x ** (2 * b))

    def prob_low_dim_umap(self, Y):
        euclid_distances = euclidean_distances(Y, Y)
        Q = (1 + self.a * euclid_distances ** (2 * self.b))
        return np.power(Q, -1)

    def CE(self, P, Y):
        Q = self.prob_low_dim_umap(Y)
        CE_term1 = -P * np.log(Q + 0.01)
        CE_term2 = - (1 - P) * np.log(1 - Q + 0.01)
        return CE_term1 + CE_term2

    def CE_gradient(self, P, Y):
        y_diff = np.expand_dims(Y, 1) - np.expand_dims(Y, 0)
        inv_dist = np.power((1 + self.a * euclidean_distances(Y, Y) ** (2 * self.b)), -1)
        Q = np.dot((1 - P), np.power((1 + self.a * euclidean_distances(Y, Y) ** (2 * self.b)), -1))
        np.fill_diagonal(Q, 0)
        Q = Q / np.sum(Q, axis=1, keepdims=True)
        fact = np.expand_dims(self.a * P * (1e-8 + np.square(euclidean_distances(Y, Y))) ** (self.b - 1) - Q, 2)
        return 2 * self.b * np.sum(fact * y_diff * np.expand_dims(inv_dist, 2), axis=1)

    def fit(self):
        x = np.linspace(0, 3, 100)
        p, _ = optimize.curve_fit(self.dist_low_dim, x, self.func(x, self.min_dist))
        self.a, self.b = p[0], p[1]
        np.random.seed(12345)
        #Y = np.random.normal(loc=0, scale=1, size=(self.features.shape[0], self.n_components)) # randomized initialization
        model = SpectralEmbedding(n_components=self.n_components, n_neighbors=50) # Laplacian initialization
        Y = model.fit_transform(self.features)
        self.CE_array = []

        # Integrating tqdm progress bar
        for i in tqdm(range(self.max_iter), desc="Fitting UMAP"):
            Y -= self.learning_rate * self.CE_gradient(self.P, Y)
            CE_current = np.sum(self.CE(self.P, Y)) / 1e+5
            self.CE_array.append(CE_current)

            # Updating the progress bar with the current CE loss
            tqdm.write(f"Iteration {i+1}/{self.max_iter}, CE Loss: {CE_current}")

        return Y

    # Optionally, you can add a method to plot the final UMAP embeddings
    def plot_umap(self, embeddings, title='UMAP Representation'):
        plt.scatter(embeddings[:, 0], embeddings[:, 1], s=50, cmap='tab10')
        plt.title(title, fontsize=14)
        plt.xlabel("UMAP1", fontsize=12)
        plt.ylabel("UMAP2", fontsize=12)
        plt.show()


In [10]:
from tqdm import tqdm
from torchvision import datasets

model = simclr_model.get_encoder()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


def get_representations(model, dataloader, device):
    model.eval()
    representations = []
    with torch.no_grad():
        for images, _ in dataloader:  # Removed tqdm for debugging
            images = images.to(device)
            features = model(images)  # Directly use model as it is the encoder
            representations.append(features.cpu().numpy())
    return np.concatenate(representations, axis=0)

# Testing the modified function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
representations = get_representations(model, testloader, device)


Files already downloaded and verified


In [12]:
print(representations.shape)

(10000, 1000)


In [None]:
# Apply UMAP
umap_representation = UMAPRepresentation(representations, n_neighbors=15, min_dist=0.5, n_components=2, learning_rate=1, max_iter=200)
umap_embeddings = umap_representation.fit()

# Extract class labels from CIFAR10
class_labels = []
for _, labels in testloader:
    class_labels.extend(labels.tolist())

# Convert to a NumPy array for easier handling
import numpy as np
class_labels = np.array(class_labels)

# Plotting the UMAP embeddings
plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=class_labels, cmap='tab10', s=50)
plt.title('UMAP Representation of CIFAR10 Test Set', fontsize=14)
plt.xlabel("UMAP1", fontsize=12)
plt.ylabel("UMAP2", fontsize=12)
plt.colorbar()  # Add a color bar to show the correspondence to class labels
plt.show()

  return np.power(2, np.sum(prob))
Fitting UMAP:   0%|          | 1/200 [00:35<1:58:18, 35.67s/it]

Iteration 1/200, CE Loss: 4602.981807118203


Fitting UMAP:   1%|          | 2/200 [01:09<1:54:33, 34.72s/it]

Iteration 2/200, CE Loss: 4591.415876299777


Fitting UMAP:   2%|▏         | 3/200 [01:45<1:55:10, 35.08s/it]

Iteration 3/200, CE Loss: 4358.371576477033


Fitting UMAP:   2%|▏         | 4/200 [02:20<1:54:22, 35.01s/it]

Iteration 4/200, CE Loss: 3069.418462232056


Fitting UMAP:   2%|▎         | 5/200 [02:55<1:54:03, 35.10s/it]

Iteration 5/200, CE Loss: 1312.7768705117364


Fitting UMAP:   3%|▎         | 6/200 [03:31<1:54:06, 35.29s/it]

Iteration 6/200, CE Loss: 637.4388290040908


Fitting UMAP:   4%|▎         | 7/200 [04:05<1:52:23, 34.94s/it]

Iteration 7/200, CE Loss: 460.7260175670912


Fitting UMAP:   4%|▍         | 8/200 [04:41<1:52:36, 35.19s/it]

Iteration 8/200, CE Loss: 373.91624861773835


Fitting UMAP:   4%|▍         | 9/200 [05:15<1:51:08, 34.91s/it]

Iteration 9/200, CE Loss: 307.66937250433665


Fitting UMAP:   5%|▌         | 10/200 [05:50<1:51:02, 35.06s/it]

Iteration 10/200, CE Loss: 257.7870663451329


Fitting UMAP:   6%|▌         | 11/200 [06:25<1:50:40, 35.13s/it]

Iteration 11/200, CE Loss: 229.5498415722989


Fitting UMAP:   6%|▌         | 12/200 [07:00<1:49:44, 35.02s/it]

Iteration 12/200, CE Loss: 204.4592538953846


Fitting UMAP:   6%|▋         | 13/200 [07:36<1:49:46, 35.22s/it]

Iteration 13/200, CE Loss: 189.57204718740252


Fitting UMAP:   7%|▋         | 14/200 [08:10<1:48:10, 34.89s/it]

Iteration 14/200, CE Loss: 175.97358295559943


Fitting UMAP:   8%|▊         | 15/200 [08:46<1:48:22, 35.15s/it]

Iteration 15/200, CE Loss: 167.7735218245653


Fitting UMAP:   8%|▊         | 16/200 [09:21<1:47:30, 35.06s/it]

Iteration 16/200, CE Loss: 159.02077122006693


Fitting UMAP:   8%|▊         | 17/200 [09:57<1:47:54, 35.38s/it]

Iteration 17/200, CE Loss: 154.3935355047844


Fitting UMAP:   9%|▉         | 18/200 [10:32<1:47:26, 35.42s/it]

Iteration 18/200, CE Loss: 148.2892864624364


Fitting UMAP:  10%|▉         | 19/200 [11:06<1:45:39, 35.03s/it]

Iteration 19/200, CE Loss: 144.44309303085473


Fitting UMAP:  10%|█         | 20/200 [11:42<1:45:37, 35.21s/it]

Iteration 20/200, CE Loss: 140.17689987760315


Fitting UMAP:  10%|█         | 21/200 [12:16<1:44:10, 34.92s/it]

Iteration 21/200, CE Loss: 137.3415997927709


Fitting UMAP:  11%|█         | 22/200 [12:52<1:43:58, 35.05s/it]

Iteration 22/200, CE Loss: 133.83930042145533


Fitting UMAP:  12%|█▏        | 23/200 [13:27<1:43:32, 35.10s/it]

Iteration 23/200, CE Loss: 132.23619923440128


Fitting UMAP:  12%|█▏        | 24/200 [14:02<1:42:35, 34.98s/it]

Iteration 24/200, CE Loss: 129.3309512611054


Fitting UMAP:  12%|█▎        | 25/200 [14:37<1:42:24, 35.11s/it]

Iteration 25/200, CE Loss: 127.94206771077855


Fitting UMAP:  13%|█▎        | 26/200 [15:11<1:40:41, 34.72s/it]

Iteration 26/200, CE Loss: 125.33544841268143


Fitting UMAP:  14%|█▎        | 27/200 [15:46<1:40:51, 34.98s/it]

Iteration 27/200, CE Loss: 124.64956843631622


Fitting UMAP:  14%|█▍        | 28/200 [16:20<1:39:31, 34.72s/it]

Iteration 28/200, CE Loss: 121.97732495085934


Fitting UMAP:  14%|█▍        | 29/200 [16:56<1:39:24, 34.88s/it]

Iteration 29/200, CE Loss: 121.73673264923096


Fitting UMAP:  15%|█▌        | 30/200 [17:31<1:38:56, 34.92s/it]

Iteration 30/200, CE Loss: 119.77454465648287


Fitting UMAP:  16%|█▌        | 31/200 [18:05<1:38:07, 34.84s/it]

Iteration 31/200, CE Loss: 119.11796470542498


Fitting UMAP:  16%|█▌        | 32/200 [18:41<1:38:09, 35.06s/it]

Iteration 32/200, CE Loss: 117.37319566618679


Fitting UMAP:  16%|█▋        | 33/200 [19:15<1:36:36, 34.71s/it]

Iteration 33/200, CE Loss: 116.96764469383949


Fitting UMAP:  17%|█▋        | 34/200 [19:51<1:36:56, 35.04s/it]

Iteration 34/200, CE Loss: 115.59514989677264


Fitting UMAP:  18%|█▊        | 35/200 [20:25<1:35:50, 34.85s/it]

Iteration 35/200, CE Loss: 114.99021886107896


Fitting UMAP:  18%|█▊        | 36/200 [21:00<1:35:40, 35.00s/it]

Iteration 36/200, CE Loss: 113.82586155121058


Fitting UMAP:  18%|█▊        | 37/200 [21:36<1:35:15, 35.07s/it]

Iteration 37/200, CE Loss: 113.34527624874696


Fitting UMAP:  19%|█▉        | 38/200 [22:10<1:34:02, 34.83s/it]

Iteration 38/200, CE Loss: 112.47295680169701


Fitting UMAP:  20%|█▉        | 39/200 [22:46<1:34:08, 35.09s/it]

Iteration 39/200, CE Loss: 111.95033776710258


Fitting UMAP:  20%|██        | 40/200 [23:20<1:32:51, 34.82s/it]

Iteration 40/200, CE Loss: 111.36014764644517


Fitting UMAP:  20%|██        | 41/200 [23:57<1:34:09, 35.53s/it]

Iteration 41/200, CE Loss: 110.54812829404322


Fitting UMAP:  21%|██        | 42/200 [24:32<1:33:27, 35.49s/it]

Iteration 42/200, CE Loss: 110.32641120776988


Fitting UMAP:  22%|██▏       | 43/200 [25:07<1:32:01, 35.17s/it]

Iteration 43/200, CE Loss: 109.57558260294327


Fitting UMAP:  22%|██▏       | 44/200 [25:42<1:31:38, 35.25s/it]

Iteration 44/200, CE Loss: 109.24670748282628


Fitting UMAP:  22%|██▎       | 45/200 [26:16<1:30:02, 34.86s/it]

Iteration 45/200, CE Loss: 108.3159340676683


Fitting UMAP:  23%|██▎       | 46/200 [26:52<1:29:54, 35.03s/it]

Iteration 46/200, CE Loss: 108.27972316626962


Fitting UMAP:  24%|██▎       | 47/200 [27:26<1:29:02, 34.92s/it]

Iteration 47/200, CE Loss: 107.67202917741672


Fitting UMAP:  24%|██▍       | 48/200 [28:02<1:28:42, 35.02s/it]

Iteration 48/200, CE Loss: 107.46292778000937


Fitting UMAP:  24%|██▍       | 49/200 [28:37<1:28:24, 35.13s/it]

Iteration 49/200, CE Loss: 107.38944048507227


Fitting UMAP:  25%|██▌       | 50/200 [29:11<1:27:00, 34.80s/it]

Iteration 50/200, CE Loss: 106.4741594854296


Fitting UMAP:  26%|██▌       | 51/200 [29:47<1:27:00, 35.04s/it]

Iteration 51/200, CE Loss: 106.54523220180072


Fitting UMAP:  26%|██▌       | 52/200 [30:22<1:26:38, 35.13s/it]

Iteration 52/200, CE Loss: 105.98573353089435


Fitting UMAP:  26%|██▋       | 53/200 [30:57<1:26:19, 35.24s/it]

Iteration 53/200, CE Loss: 105.88456039696177


Fitting UMAP:  27%|██▋       | 54/200 [31:32<1:25:35, 35.18s/it]

Iteration 54/200, CE Loss: 105.32438755099868


Fitting UMAP:  28%|██▊       | 55/200 [32:07<1:24:41, 35.04s/it]

Iteration 55/200, CE Loss: 105.2605000915731


Fitting UMAP:  28%|██▊       | 56/200 [32:43<1:24:19, 35.14s/it]

Iteration 56/200, CE Loss: 105.09892737252247


Fitting UMAP:  28%|██▊       | 57/200 [33:16<1:22:53, 34.78s/it]

Iteration 57/200, CE Loss: 104.80188037314245


Fitting UMAP:  29%|██▉       | 58/200 [33:52<1:22:52, 35.02s/it]

Iteration 58/200, CE Loss: 104.651351906733


Fitting UMAP:  30%|██▉       | 59/200 [34:27<1:22:01, 34.90s/it]

Iteration 59/200, CE Loss: 104.33993553536978


Fitting UMAP:  30%|███       | 60/200 [35:02<1:21:49, 35.07s/it]

Iteration 60/200, CE Loss: 104.10217408223654


Fitting UMAP:  30%|███       | 61/200 [35:38<1:21:29, 35.17s/it]

Iteration 61/200, CE Loss: 103.80933193421151


Fitting UMAP:  31%|███       | 62/200 [36:12<1:20:13, 34.88s/it]

Iteration 62/200, CE Loss: 103.7607458621973


Fitting UMAP:  32%|███▏      | 63/200 [36:47<1:20:02, 35.05s/it]

Iteration 63/200, CE Loss: 103.2534982752778


Fitting UMAP:  32%|███▏      | 64/200 [37:21<1:18:39, 34.70s/it]

Iteration 64/200, CE Loss: 103.52676376221294


Fitting UMAP:  32%|███▎      | 65/200 [37:56<1:18:32, 34.90s/it]

Iteration 65/200, CE Loss: 102.90799420851236


Fitting UMAP:  33%|███▎      | 66/200 [38:31<1:17:47, 34.84s/it]

Iteration 66/200, CE Loss: 103.01385537156573


Fitting UMAP:  34%|███▎      | 67/200 [39:06<1:17:27, 34.94s/it]

Iteration 67/200, CE Loss: 102.90145592325455


Fitting UMAP:  34%|███▍      | 68/200 [39:42<1:17:19, 35.15s/it]

Iteration 68/200, CE Loss: 102.61345472242462


Fitting UMAP:  34%|███▍      | 69/200 [40:16<1:16:08, 34.87s/it]

Iteration 69/200, CE Loss: 102.61627860136062


Fitting UMAP:  35%|███▌      | 70/200 [40:51<1:15:48, 34.99s/it]

Iteration 70/200, CE Loss: 102.21691295091945


Fitting UMAP:  36%|███▌      | 71/200 [41:25<1:14:38, 34.71s/it]

Iteration 71/200, CE Loss: 102.25529405367679


Fitting UMAP:  36%|███▌      | 72/200 [42:01<1:14:33, 34.95s/it]

Iteration 72/200, CE Loss: 102.17023800109925


Fitting UMAP:  36%|███▋      | 73/200 [42:36<1:13:45, 34.84s/it]

Iteration 73/200, CE Loss: 101.97660972930284


Fitting UMAP:  37%|███▋      | 74/200 [43:11<1:13:22, 34.94s/it]

Iteration 74/200, CE Loss: 101.74438575293405


Fitting UMAP:  38%|███▊      | 75/200 [43:46<1:13:02, 35.06s/it]

Iteration 75/200, CE Loss: 101.90642016518072


Fitting UMAP:  38%|███▊      | 76/200 [44:20<1:11:54, 34.79s/it]

Iteration 76/200, CE Loss: 101.39811978523528


Fitting UMAP:  38%|███▊      | 77/200 [44:56<1:11:48, 35.03s/it]

Iteration 77/200, CE Loss: 101.66814829484768


Fitting UMAP:  39%|███▉      | 78/200 [45:30<1:10:37, 34.74s/it]

Iteration 78/200, CE Loss: 101.09942899567208


Fitting UMAP:  40%|███▉      | 79/200 [46:05<1:10:30, 34.96s/it]

Iteration 79/200, CE Loss: 101.38293329184543


Fitting UMAP:  40%|████      | 80/200 [46:40<1:09:53, 34.94s/it]

Iteration 80/200, CE Loss: 100.92510090794826


Fitting UMAP:  40%|████      | 81/200 [47:15<1:09:11, 34.89s/it]

Iteration 81/200, CE Loss: 101.11658958629822


Fitting UMAP:  41%|████      | 82/200 [47:51<1:08:57, 35.06s/it]

Iteration 82/200, CE Loss: 100.78956288134489


Fitting UMAP:  42%|████▏     | 83/200 [48:25<1:07:56, 34.84s/it]

Iteration 83/200, CE Loss: 100.70800653098276


Fitting UMAP:  42%|████▏     | 84/200 [49:00<1:07:42, 35.02s/it]

Iteration 84/200, CE Loss: 100.51756063811116


Fitting UMAP:  42%|████▎     | 85/200 [49:34<1:06:37, 34.76s/it]

Iteration 85/200, CE Loss: 100.74658400239237


Fitting UMAP:  43%|████▎     | 86/200 [50:10<1:06:28, 34.98s/it]

Iteration 86/200, CE Loss: 100.11602937795654


Fitting UMAP:  44%|████▎     | 87/200 [50:45<1:05:53, 34.98s/it]

Iteration 87/200, CE Loss: 100.36309987201851


Fitting UMAP:  44%|████▍     | 88/200 [51:20<1:05:24, 35.04s/it]

Iteration 88/200, CE Loss: 99.84975332530344


Fitting UMAP:  44%|████▍     | 89/200 [51:55<1:04:58, 35.12s/it]

Iteration 89/200, CE Loss: 100.43389411300494


Fitting UMAP:  45%|████▌     | 90/200 [52:30<1:03:50, 34.82s/it]

Iteration 90/200, CE Loss: 99.55732252615255


Fitting UMAP:  46%|████▌     | 91/200 [53:05<1:03:46, 35.11s/it]

Iteration 91/200, CE Loss: 100.00018965385789


Fitting UMAP:  46%|████▌     | 92/200 [53:39<1:02:40, 34.82s/it]

Iteration 92/200, CE Loss: 99.35112438982593


Fitting UMAP:  46%|████▋     | 93/200 [54:15<1:02:28, 35.04s/it]

Iteration 93/200, CE Loss: 99.92505313582639


Fitting UMAP:  47%|████▋     | 94/200 [54:50<1:01:41, 34.92s/it]

Iteration 94/200, CE Loss: 99.15022808358677


Fitting UMAP:  48%|████▊     | 95/200 [55:25<1:01:05, 34.91s/it]

Iteration 95/200, CE Loss: 99.55226298676392


Fitting UMAP:  48%|████▊     | 96/200 [56:00<1:00:39, 34.99s/it]

Iteration 96/200, CE Loss: 99.1449590293032


Fitting UMAP:  48%|████▊     | 97/200 [56:34<59:36, 34.72s/it]  

Iteration 97/200, CE Loss: 99.36092046156794


Fitting UMAP:  49%|████▉     | 98/200 [57:09<59:21, 34.92s/it]

Iteration 98/200, CE Loss: 98.72523239766505


Fitting UMAP:  50%|████▉     | 99/200 [57:43<58:18, 34.64s/it]

Iteration 99/200, CE Loss: 99.21739608780506


Fitting UMAP:  50%|█████     | 100/200 [58:19<58:09, 34.90s/it]

Iteration 100/200, CE Loss: 98.69238101743485


Fitting UMAP:  50%|█████     | 101/200 [58:53<57:23, 34.78s/it]

Iteration 101/200, CE Loss: 98.86900952112968


Fitting UMAP:  51%|█████     | 102/200 [59:28<56:54, 34.85s/it]

Iteration 102/200, CE Loss: 98.6140837133244


Fitting UMAP:  52%|█████▏    | 103/200 [1:00:04<56:39, 35.04s/it]

Iteration 103/200, CE Loss: 98.77540966754427


Fitting UMAP:  52%|█████▏    | 104/200 [1:00:38<55:38, 34.78s/it]

Iteration 104/200, CE Loss: 98.53538102705441


Fitting UMAP:  52%|█████▎    | 105/200 [1:01:13<55:22, 34.98s/it]

Iteration 105/200, CE Loss: 98.4972146517278


Fitting UMAP:  53%|█████▎    | 106/200 [1:01:47<54:22, 34.70s/it]

Iteration 106/200, CE Loss: 98.14185730705736


Fitting UMAP:  54%|█████▎    | 107/200 [1:02:23<54:09, 34.94s/it]

Iteration 107/200, CE Loss: 98.11079110782792


Fitting UMAP:  54%|█████▍    | 108/200 [1:02:58<53:40, 35.00s/it]

Iteration 108/200, CE Loss: 97.94797263760577


Fitting UMAP:  55%|█████▍    | 109/200 [1:03:33<52:57, 34.92s/it]

Iteration 109/200, CE Loss: 97.92592390576219


Fitting UMAP:  55%|█████▌    | 110/200 [1:04:08<52:39, 35.11s/it]

Iteration 110/200, CE Loss: 97.83728522273512


Fitting UMAP:  56%|█████▌    | 111/200 [1:04:42<51:36, 34.79s/it]

Iteration 111/200, CE Loss: 97.76454731536576


Fitting UMAP:  56%|█████▌    | 112/200 [1:05:18<51:23, 35.04s/it]

Iteration 112/200, CE Loss: 97.69056696792943


Fitting UMAP:  56%|█████▋    | 113/200 [1:05:52<50:28, 34.81s/it]

Iteration 113/200, CE Loss: 97.65541769665343


Fitting UMAP:  57%|█████▋    | 114/200 [1:06:28<50:07, 34.97s/it]

Iteration 114/200, CE Loss: 97.70731500803973


Fitting UMAP:  57%|█████▊    | 115/200 [1:07:02<49:29, 34.94s/it]

Iteration 115/200, CE Loss: 97.5351802004908


Fitting UMAP:  58%|█████▊    | 116/200 [1:07:37<48:51, 34.90s/it]

Iteration 116/200, CE Loss: 97.55963155519646


Fitting UMAP:  58%|█████▊    | 117/200 [1:08:13<48:33, 35.10s/it]

Iteration 117/200, CE Loss: 97.55374546040954


Fitting UMAP:  59%|█████▉    | 118/200 [1:08:47<47:28, 34.74s/it]

Iteration 118/200, CE Loss: 97.33867181325044


Fitting UMAP:  60%|█████▉    | 119/200 [1:09:22<47:08, 34.92s/it]

Iteration 119/200, CE Loss: 97.45902058770031


Fitting UMAP:  60%|██████    | 120/200 [1:09:56<46:17, 34.72s/it]

Iteration 120/200, CE Loss: 97.0908370024976


Fitting UMAP:  60%|██████    | 121/200 [1:10:32<45:57, 34.91s/it]

Iteration 121/200, CE Loss: 97.21994958956859


Fitting UMAP:  61%|██████    | 122/200 [1:11:06<45:16, 34.83s/it]

Iteration 122/200, CE Loss: 96.92323988610124


Fitting UMAP:  62%|██████▏   | 123/200 [1:11:41<44:46, 34.89s/it]

Iteration 123/200, CE Loss: 96.91125417463452


Fitting UMAP:  62%|██████▏   | 124/200 [1:12:17<44:24, 35.06s/it]

Iteration 124/200, CE Loss: 96.91174830451438


Fitting UMAP:  62%|██████▎   | 125/200 [1:12:51<43:28, 34.78s/it]

Iteration 125/200, CE Loss: 96.93939143606292


Fitting UMAP:  63%|██████▎   | 126/200 [1:13:27<43:20, 35.15s/it]

Iteration 126/200, CE Loss: 96.66559212037194


Fitting UMAP:  64%|██████▎   | 127/200 [1:14:02<42:42, 35.10s/it]

Iteration 127/200, CE Loss: 96.73333519986656


Fitting UMAP:  64%|██████▍   | 128/200 [1:14:37<42:10, 35.15s/it]

Iteration 128/200, CE Loss: 96.60956266559263


Fitting UMAP:  64%|██████▍   | 129/200 [1:15:13<41:52, 35.38s/it]

Iteration 129/200, CE Loss: 96.73741656565346


Fitting UMAP:  65%|██████▌   | 130/200 [1:15:47<40:54, 35.07s/it]

Iteration 130/200, CE Loss: 96.24234552164353


Fitting UMAP:  66%|██████▌   | 131/200 [1:16:23<40:27, 35.19s/it]

Iteration 131/200, CE Loss: 96.68548360796339


Fitting UMAP:  66%|██████▌   | 132/200 [1:16:57<39:34, 34.92s/it]

Iteration 132/200, CE Loss: 96.49786493320383


Fitting UMAP:  66%|██████▋   | 133/200 [1:17:33<39:13, 35.12s/it]

Iteration 133/200, CE Loss: 96.46995688596598


Fitting UMAP:  67%|██████▋   | 134/200 [1:18:08<38:35, 35.08s/it]

Iteration 134/200, CE Loss: 96.19937396874785


Fitting UMAP:  68%|██████▊   | 135/200 [1:18:42<37:49, 34.92s/it]

Iteration 135/200, CE Loss: 96.42240317890312


Fitting UMAP:  68%|██████▊   | 136/200 [1:19:18<37:31, 35.18s/it]

Iteration 136/200, CE Loss: 95.92788348949352


Fitting UMAP:  68%|██████▊   | 137/200 [1:19:52<36:36, 34.87s/it]

Iteration 137/200, CE Loss: 96.40851531723231


Fitting UMAP:  69%|██████▉   | 138/200 [1:20:28<36:16, 35.11s/it]

Iteration 138/200, CE Loss: 95.8120583773644
