In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch import nn
import torch.nn.functional as F
from SproutDataset import SproutDataset
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import SimSiamTransform
import torch
from Siamese_Architecture import Siamese1DNet_backbone, SimSiam

import matplotlib.pyplot as plt

BATCH_SIZE = 32  # spectra per mini-batch in the training DataLoader

In [None]:
import tqdm
import umap.umap_ as umap
import numpy as np
import hdbscan
import sklearn.metrics as metrics
from SproutDataset import map_item_map

Self-supervised search: run inference on the backbone + projection head

Feed each spectrum through the backbone, then pass the backbone output to the projection head. This yields 128 values, i.e. each spectrum is encoded as a vector of size 128.
We end up with one such vector for every spectrum.

Take the 128-value vectors of all spectra and pass them to UMAP, which reduces each vector to a pair of (x, y) coordinates.

In [None]:
import os

# Folder holding the SPROUTS spectra files. Set SPROUTS_DATA_DIR to run on another machine
# without editing the notebook; the default is the original hard-coded location.
DATA_DIR = os.environ.get("SPROUTS_DATA_DIR", "C:\\Users\\tania\\Documents\\SPICE\\SPROUTS")
dataset_path = os.path.join(DATA_DIR, "spectra_train.nc")
dataset_path_mini = os.path.join(DATA_DIR, "spectra_train_mini.nc")

# Training data: the mini file with augmentation_type='single', reshuffled every epoch.
datasetsingle = SproutDataset(dataset_path=dataset_path_mini, augmentation_type='single')
dataloader = DataLoader(datasetsingle, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Peek at a single batch to sanity-check what the DataLoader yields.
batch = next(iter(dataloader), None)
if batch is not None:
    # The default collate may return a list (e.g. one tensor per view); stack it into one tensor.
    if isinstance(batch, list):
        batch = torch.stack(batch)
    # NOTE(review): originally expected [32, 451]; if each item is a list of views the stacked
    # shape gains a leading dimension — confirm.
    print(batch.shape)

In [None]:
# Seed Python, NumPy and torch RNGs (including DataLoader workers) so model initialisation
# and batch shuffling are reproducible from run to run.
pl.seed_everything(42, workers=True)

model = SimSiam()
dataset = datasetsingle  # alias of the training dataset

# Log metrics (and model checkpoints, via log_model=True) to Weights & Biases.
wandb_logger = WandbLogger(project="SimSiam_Training_128_contrastive", log_model=True)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

In [None]:
# Train for a fixed number of epochs on a single device, logging to Weights & Biases.
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=1,
    max_epochs=5,
    logger=wandb_logger,
)
trainer.fit(model=model, train_dataloaders=dataloader)

In [None]:
# Raw string: the previous plain literal only worked because Python passed the invalid escape
# "\e" (in "\epoch") through unchanged, which warns (SyntaxWarning since Python 3.12).
checkpoint = r"C:\Users\tania\Documents\CU Boulder\CU Fall 2024\ASEN 6337\Individual project\SPICE_DeepLearning\SimSiam_Training_128_contrastive\50a3p00v\checkpoints\epoch=4-step=18300.ckpt"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine; the
# inference cell moves the model to whichever device is available afterwards.
loaded_model = SimSiam.load_from_checkpoint(checkpoint, map_location="cpu")
# Inference mode (affects layers such as dropout / batch norm, if the model has any).
loaded_model.eval()

In [None]:
# Datasets built with augmentation_type=None (presumably no augmentation) for the full and
# mini training files; only the mini one is consumed by the inference cell below.
dataset_none, dataset_none_mini = (
    SproutDataset(dataset_path=path, augmentation_type=None)
    for path in (dataset_path, dataset_path_mini)
)

In [None]:
# Encode every spectrum of the mini dataset with the trained model, one spectrum at a time.
# Device selection and moving the model are loop-invariant, so do them once up front.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model = loaded_model.to(device)

outputs = []
with torch.no_grad():  # inference only: no autograd graph needed
    for i in tqdm.tqdm(range(len(dataset_none_mini))):
        # Add a leading batch dimension of size 1 and move to the model's device.
        spec = dataset_none_mini[i].unsqueeze(0).to(device)
        # NOTE(review): `[0]` keeps the first element of the model output — the first returned
        # tensor if forward returns a tuple, otherwise the single batch row; confirm against
        # SimSiam.forward.
        outputs.append(loaded_model(spec)[0].cpu().numpy())

In [None]:
# Gather the per-spectrum outputs into one array and drop any singleton axes
# (such as a per-item batch dimension of 1).
stacked_outputs = np.squeeze(np.stack(outputs))

In [None]:
# Display the shape of the stacked embeddings; presumably (n_spectra, 128) — confirm.
stacked_outputs.shape

In [None]:
# L2-normalise each embedding (row) so Euclidean distance between rows tracks cosine
# distance: for unit vectors ||a - b||^2 = 2 - 2*cos(a, b).
# Note: np.linalg.norm(x, ord=2) on a whole 2-D array returns ONE scalar (the spectral norm),
# which would divide every row by the same constant; axis=1 is needed for per-row norms.
# Assumes stacked_outputs is 2-D (n_spectra, embedding_dim).
row_norms = np.linalg.norm(stacked_outputs, axis=1, keepdims=True)
# Clip to avoid division by zero for an all-zero embedding.
stacked_outputs_L2 = stacked_outputs / np.clip(row_norms, 1e-12, None)

clusterer = hdbscan.HDBSCAN(min_cluster_size=100, min_samples=15, metric='euclidean')
clusterer.fit(stacked_outputs_L2)

# One label per spectrum; HDBSCAN uses -1 for noise points.
labels = clusterer.labels_

In [None]:
from sunraster.instr.spice import read_spice_l2_fits   
import xarray as xr

In [None]:
def map_item_map(item_nbr, dataset, plot=False, data_dir='C:\\Users\\tania\\Documents\\SPICE\\SPROUTS\\data_L2\\', key='Ne VIII 770 (Merged)',croplatbottom=725, croplattop=115):
    """Read the SPICE L2 raster a dataset item was extracted from, optionally marking the item.

    Note: this definition shadows the ``map_item_map`` imported from ``SproutDataset``
    earlier in the notebook.

    Parameters
    ----------
    item_nbr : int
        Position of the item along the ``index`` dimension of ``dataset``.
    dataset : xarray.Dataset
        Must expose ``filename``, ``x-index`` and ``y-index`` along ``index``.
    plot : bool, default False
        If True, display slice 20 (first axis) of the cropped cube and mark the item
        with a red ``+``; also prints the cube shape.
    data_dir : str
        Folder containing the L2 FITS files; joined with the item's filename.
    key : str
        Spectral window read from the exposure.
    croplatbottom, croplattop : int
        The third axis of the window data is sliced as ``croplattop:croplatbottom``.
    """
    import os  # local import keeps this cell self-contained

    item = dataset.isel(index=item_nbr)  # select once, reuse for every field
    filename = str(item['filename'].data)
    i, j = item['x-index'].data, item['y-index'].data
    # os.path.join inserts a separator when data_dir lacks a trailing one.
    exposure = read_spice_l2_fits(os.path.join(data_dir, filename), memmap=False)
    cube = exposure[key][0, :, croplattop:croplatbottom, :].data
    if plot:
        image = cube[20, :, :]
        # Clip the colour scale at the 99.9th percentile (NaN-aware) so bright outliers don't wash it out.
        plt.imshow(image, aspect=1/4, cmap='gist_heat', vmax=np.nanquantile(image, 0.999))
        print(cube.shape)
        # NOTE(review): the marker is drawn at (x=j, y=i) where j = 'y-index' and i = 'x-index';
        # confirm this matches the orientation of the displayed slice.
        plt.plot(j, i, color='red', marker='+')

# Raw xarray view of the mini file: map_item_map reads `filename`, `x-index` and `y-index`
# from it. Reuse the path from the data-loading cell instead of repeating the literal.
datasetmini = xr.open_dataset(dataset_path_mini)

# Fold the flat HDBSCAN labels back into raster layout. 610 rows equals the default latitude
# crop in map_item_map (725 - 115); 192 is presumably the raster width.
# NOTE(review): assumes the mini file holds exactly one raster stored row-major — confirm.
LABEL_MAP_SHAPE = (610, 192)
labels_1file = labels.reshape(LABEL_MAP_SHAPE)

plt.figure(figsize=(15, 6))
plt.subplot(211)
plt.imshow(labels_1file, cmap='Accent', aspect=1/4)
plt.colorbar(label='Cluster (-1 = noise)')
plt.title('HDBSCAN labels mapped back onto the raster')
plt.show()

# Show the source raster of one example spectrum with its pixel marked.
EXAMPLE_ITEM = 80000
map_item_map(EXAMPLE_ITEM, datasetmini, plot=True)
plt.show()

In [None]:
# Project the (un-normalised) embeddings to 2-D for visualisation only.
# A fixed random_state makes the layout reproducible.
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    random_state=42,
)
projected_data = reducer.fit_transform(stacked_outputs)

In [None]:
fig, (ax_proj, ax_density) = plt.subplots(1, 2, figsize=(18, 8))

# No `c=` values are given, so a colormap would be ignored (matplotlib warns); use one colour.
scatter = ax_proj.scatter(projected_data[:, 0], projected_data[:, 1], s=1)
ax_proj.set_title('UMAP Projection to 2D')
ax_proj.set_xlabel('UMAP Dimension 1')
ax_proj.set_ylabel('UMAP Dimension 2')

# The 2-D histogram shows where points pile up, which the overplotted scatter hides.
*_, density_image = ax_density.hist2d(projected_data[:, 0], projected_data[:, 1], bins=50)
fig.colorbar(density_image, ax=ax_density, label='Count')
ax_density.grid(True)
ax_density.set_title('Density histogram')
ax_density.set_xlabel('UMAP Dimension 1')
ax_density.set_ylabel('UMAP Dimension 2')
plt.show()

In [None]:
# Draw HDBSCAN noise points (label -1) in white.
# Work on a copy: mutating the registered `Spectral` colormap is deprecated, and in recent
# matplotlib versions passing cmap='Spectral' by name does not pick up that mutation.
# vmin=0 puts -1 below the colour range so it is drawn with the "under" colour
# (vmax is kept >= 1 so the range is never degenerate).
cmap = plt.cm.Spectral.copy()
cmap.set_under('white')

plt.figure(figsize=(15,10))
plt.scatter(
    projected_data[:, 0], projected_data[:, 1],
    c=labels, cmap=cmap, vmin=0, vmax=max(labels.max(), 1),
    s=10, alpha=0.5,
)
cbar = plt.colorbar(label='Cluster', extend='min')

plt.title('HDBSCAN Clustering - Contrastive loss\n2-dimensional output')
plt.show()

In [None]:
# Same projection, but keep only points HDBSCAN assigned to a cluster (drop label -1).
mask = labels != -1
filtered_data, filtered_labels = projected_data[mask], labels[mask]

fig, ax = plt.subplots(figsize=(15, 10))
scatter = ax.scatter(
    filtered_data[:, 0], filtered_data[:, 1],
    c=filtered_labels, cmap='Spectral', s=10,
)
fig.colorbar(scatter, ax=ax, label='Cluster')
ax.set_title('HDBSCAN Clustering - Contrastive Loss\n2-dimensional output (excluding -1)')
plt.show()