In [7]:
import numpy as np
import pandas as pd
from sklearn import preprocessing
import math

from metaspace import SMInstance

import torch.nn.functional as functional
import torch
import torchvision.transforms as transforms
import lightning.pytorch as pl
import os
import pickle

from ionimage_embedding.dataloader.clr_dataloader import get_clr_dataloader
from ionimage_embedding.models.clr.cae import CAE
from ionimage_embedding.models.clr.clr_model import CLRmodel
from ionimage_embedding.models.clr.pseudo_labeling import run_knn, string_similarity_matrix

In [8]:
def download_data(evaluation_datasets, testing_dsid, training_dsid):
    """Download METASPACE ion images and stack them into training/testing arrays.

    For every dataset id in ``evaluation_datasets`` the HMDB v4 annotations at
    FDR 0.2 are fetched, annotations flagged ``offSample`` are dropped, and the
    first image of each remaining first-isotope annotation is kept. The
    per-dataset image stacks are then passed through ``size_adaption_symmetric``
    and concatenated.

    Args:
        evaluation_datasets: METASPACE dataset ids to download.
        testing_dsid: currently unused; kept for interface compatibility.
        training_dsid: collection of dataset ids whose images go into the
            training arrays (tested with ``in``).

    Returns:
        ``(training_data, training_datasets, training_ions,
        testing_data, testing_datasets, testing_ions)`` as numpy arrays.
        ``*_data`` are the stacked images, ``*_datasets`` the dataset id of each
        image and ``*_ions`` its formula+adduct string.

    NOTE(review): the testing arrays contain images from *every* downloaded
    dataset (training ones included), not only ``testing_dsid`` -- confirm
    this is intended.
    """
    sm = SMInstance()

    images_per_ds = {}
    ions_per_ds = {}
    for dsid in evaluation_datasets:
        ds = sm.dataset(id=dsid)
        results = ds.results(database=("HMDB", "v4"), fdr=0.2).reset_index()
        # formula+adduct -> True when the annotation is NOT off-sample.
        onsample = dict(zip(results['formula'].str.cat(results['adduct']), ~results['offSample']))
        annotation_images = ds.all_annotation_images(fdr=0.2, database=("HMDB", "v4"), only_first_isotope=True)

        # Filter once so that ion names and images stay aligned by construction.
        kept = [img for img in annotation_images if onsample[img.formula + img.adduct]]
        images_per_ds[dsid] = np.array([img._images[0] for img in kept])
        ions_per_ds[dsid] = [img.formula + img.adduct for img in kept]

    # Bring all datasets to a common image size. NOTE(review): this helper is
    # not defined or imported in the visible cells -- make sure it is in scope.
    padded_images = size_adaption_symmetric(images_per_ds)

    training_data, training_datasets, training_ions = [], [], []
    testing_data, testing_datasets, testing_ions = [], [], []
    for dsid, imgs in padded_images.items():
        n_imgs = imgs.shape[0]
        if dsid in training_dsid:
            training_data.append(imgs)
            training_datasets += [dsid] * n_imgs
            training_ions += ions_per_ds[dsid]

        # Every dataset (training ones included) goes into the testing arrays.
        testing_data.append(imgs)
        testing_datasets += [dsid] * n_imgs
        testing_ions += ions_per_ds[dsid]

    return (
        np.concatenate(training_data),
        np.array(training_datasets),
        np.array(training_ions),
        np.concatenate(testing_data),
        np.array(testing_datasets),
        np.array(testing_ions),
    )


def load_data(cache=False, cache_folder='/scratch/model_testing'):
    """Return the train/test arrays from ``download_data``, optionally cached on disk.

    The last entry of the hard-coded dataset list is passed to
    ``download_data`` as ``testing_dsid`` and all other entries as
    ``training_dsid``.

    Args:
        cache: when True, load the pickled result from ``cache_folder`` if a
            cache file for this dataset list exists; otherwise download it and
            write the cache file.
        cache_folder: directory holding the cache file. It is created,
            including missing parent directories, when needed.

    Returns:
        The 6-tuple produced by ``download_data``.
    """
    evaluation_datasets = [
        '2022-12-07_02h13m50s',
        '2022-12-07_02h13m20s',
        '2022-12-07_02h10m45s',
        '2022-12-07_02h09m41s',
        '2022-12-07_02h08m52s',
        '2022-12-07_01h02m53s',
        '2022-12-07_01h01m06s',
        '2022-11-28_22h24m25s',
        '2022-11-28_22h23m30s',
    ]

    training_dsid = evaluation_datasets[:-1]
    testing_dsid = evaluation_datasets[-1]

    if not cache:
        return download_data(evaluation_datasets, testing_dsid, training_dsid)

    # Cache key is the plain concatenation of the dataset ids (not a hash);
    # kept as-is so existing cache files stay valid.
    cache_file = 'clr_{}.pickle'.format(''.join(evaluation_datasets))
    cache_path = os.path.join(cache_folder, cache_file)

    # makedirs also creates missing parents (os.mkdir would raise).
    os.makedirs(cache_folder, exist_ok=True)

    if os.path.isfile(cache_path):
        print('Loading cached data from: {}'.format(cache_path))
        # NOTE: pickle.load can execute arbitrary code -- only load cache
        # files this code wrote itself.
        with open(cache_path, "rb") as fh:
            return pickle.load(fh)

    data = download_data(evaluation_datasets, testing_dsid, training_dsid)
    # Write to a temporary file and atomically move it into place, so an
    # interrupted dump never leaves a truncated cache that breaks later runs.
    tmp_path = cache_path + '.tmp'
    with open(tmp_path, "wb") as fh:
        pickle.dump(data, fh)
    os.replace(tmp_path, cache_path)
    print('Saved file: {}'.format(cache_path))
    return data

In [9]:
# Fetch the train/test arrays, reusing the on-disk pickle cache when present.
(training_data, training_datasets, training_ions,
 testing_data, testing_datasets, testing_ions) = load_data(
    cache=True,
    cache_folder='/scratch/model_testing',
)

Loading cached data from: /scratch/model_testing/clr_2022-12-07_02h13m50s2022-12-07_02h13m20s2022-12-07_02h10m45s2022-12-07_02h09m41s2022-12-07_02h08m52s2022-12-07_01h02m53s2022-12-07_01h01m06s2022-11-28_22h24m25s2022-11-28_22h23m30s.pickle


In [12]:
 # Image normalization
sampleN = len(training_data)
val_data_fraction = .3
val_split_seed = 42  # fixed seed so the train/validation split is reproducible

# Min-max scale every image to [0, 1] independently (in place, like before).
# Constant images (max == min) would give 0/0 = NaN; their range is set to 1
# so they map to all zeros instead.
image_axes = tuple(range(1, training_data.ndim))
image_min = training_data.min(axis=image_axes, keepdims=True)
image_range = training_data.max(axis=image_axes, keepdims=True) - image_min
image_range[image_range == 0] = 1
training_data[:] = (training_data - image_min) / image_range

# Hold out a random subset for validation. Sample WITHOUT replacement:
# np.random.randint can draw the same index repeatedly, which shrinks the
# effective validation fraction and duplicates samples in the validation loader.
rng = np.random.default_rng(val_split_seed)
val_mask = rng.choice(sampleN, size=math.floor(sampleN * val_data_fraction), replace=False)
training_mask = np.ones(sampleN, bool)
training_mask[val_mask] = False

In [13]:
# Turn the string labels into integer class ids: one encoder for the dataset
# id of each image, one for its ion (formula + adduct) string.
ds_encoder = preprocessing.LabelEncoder()
dsl_int = torch.tensor(ds_encoder.fit_transform(training_datasets))

il_encoder = preprocessing.LabelEncoder()
ill_int = torch.tensor(il_encoder.fit_transform(training_ions))

# Image geometry and loader settings.
sampleN = len(training_data)
height, width = training_data.shape[1], training_data.shape[2]
batch_size = 128

In [15]:
# Training loader over the samples not held out for validation; images are
# randomly rotated by 0-360 degrees.
tdl = get_clr_dataloader(
    images=training_data[training_mask],
    dataset_labels=dsl_int[training_mask],
    ion_labels=ill_int[training_mask],
    index=np.arange(training_data.shape[0])[training_mask],
    height=height,
    width=width,
    transform=transforms.RandomRotation(degrees=(0, 360)),
    batch_size=batch_size,
)

In [18]:
# Validation loader over the held-out indices in val_mask; uses the same
# random 0-360 degree rotation as the training loader.
vdl = get_clr_dataloader(
    images=training_data[val_mask],
    dataset_labels=dsl_int[val_mask],
    ion_labels=ill_int[val_mask],
    index=np.arange(training_data.shape[0])[val_mask],
    height=height,
    width=width,
    transform=transforms.RandomRotation(degrees=(0, 360)),
    batch_size=batch_size,
)

In [None]:
# Pull one batch from the training loader to sanity-check its contents.
(dl_image,
 dl_sample_id,
 dl_dataset_label,
 dl_ion_label) = next(iter(tdl))

In [19]:
# Pull one batch from the validation loader (rebinds the same dl_* names).
(dl_image,
 dl_sample_id,
 dl_dataset_label,
 dl_ion_label) = next(iter(vdl))

In [23]:
# Show the dataset labels of the most recently drawn batch (the validation
# batch from the previous cell). Presumably these are the ds_encoder integer
# codes passed to the loader -- TODO confirm against get_clr_dataloader.
dl_dataset_label

tensor([[6],
        [1],
        [6],
        [5],
        [4],
        [6],
        [0],
        [5],
        [6],
        [1],
        [3],
        [5],
        [2],
        [6],
        [1],
        [3],
        [7],
        [3],
        [5],
        [4],
        [1],
        [6],
        [7],
        [0],
        [6],
        [4],
        [0],
        [3],
        [3],
        [7],
        [2],
        [4],
        [4],
        [1],
        [1],
        [5],
        [6],
        [5],
        [3],
        [7],
        [2],
        [0],
        [2],
        [2],
        [1],
        [6],
        [4],
        [1],
        [4],
        [6],
        [3],
        [4],
        [4],
        [7],
        [0],
        [2],
        [3],
        [0],
        [7],
        [3],
        [4],
        [4],
        [0],
        [3],
        [3],
        [4],
        [2],
        [2],
        [5],
        [3],
        [2],
        [5],
        [0],
        [0],
        [3],
        [0],
        [6],