In [45]:
import os

import numpy as np
import pandas as pd
import torch
from scipy.io import mmread
from torch.utils.data import Dataset

# Create new training, test and validation splits

In [None]:
# Load the original train / test / validation split tables
file_train = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata/datasplit1-train.csv')
file_test = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata/datasplit1-test.csv')
file_val = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata/datasplit1-val.csv')

# Table whose CPD_NAME column defines the compounds to hold out (used below to build `unseen`)
unseen_examples = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata/unseen_examples.csv')

In [10]:
# Compound names that occur in both the validation and the test table of the original split
np.intersect1d(file_val.CPD_NAME, file_test.CPD_NAME)

array(['boldine', 'naproxen'], dtype=object)

In [11]:
# Pool the three original splits into one table; ignore_index=True gives the rows a fresh 0..n-1 index
dataset = pd.concat([file_train, file_test, file_val], ignore_index=True)

In [13]:
# Show the column names of the training split table
print(file_train.columns)

Index(['SAMPLE_KEY', 'BROAD_ID', 'PLATE_ID', 'WELL_POSITION', 'SITE',
       'SAMPLE_ID', 'CPD_NAME', 'CPD_NAME_TYPE', 'SMILES', 'INCHIKEY',
       'IMG_ERSyto', 'IMG_ERSytoBleed', 'IMG_Hoechst', 'IMG_Mito',
       'IMG_Ph_golgi', 'IMG_CNT_CELLS', 'ROW_NR_LABEL_MAT'],
      dtype='object')


In [14]:
# Unique compound names of the pooled dataset, and how many there are
molecules = np.unique(dataset.CPD_NAME)
len(molecules)

10560

In [15]:
# Total number of observations in the pooled dataset
n = len(dataset)
# NOTE(review): despite the `_perc` names these are row counts, and the fractions add up to 1.10;
# none of the three values is used by the cells visible below - confirm before relying on them
train_perc, test_perc, valid_perc = np.round(n*0.80), np.round(n*0.20),  np.round(n*0.10)

# Each molecule counts either 48 or 24 entries 

In [16]:
np.random.seed(42)
# Compounds listed in the hold-out table are the unseen (out-of-distribution) molecules
unseen = list(np.unique(unseen_examples.CPD_NAME))
# Every other compound of the pooled dataset is seen. The set difference is sorted because the
# iteration order of a set of strings changes between interpreter runs; an unordered list would
# make the per-molecule shuffling further down irreproducible even with a fixed seed.
seen = sorted(set(dataset.CPD_NAME) - set(unseen))

In [17]:
# Sanity check: the seen and unseen compound lists must be disjoint (expected result: empty set)
set(seen).intersection(unseen)

set()

In [18]:
# Row positions of the pooled dataset whose compound is seen / not seen.
# `isin` does the membership test with one hashed lookup per row instead of scanning the
# `seen` list once for every row.
is_seen = dataset.CPD_NAME.isin(seen).to_numpy()
seen_rows = np.flatnonzero(is_seen).tolist()
unseen_rows = np.flatnonzero(~is_seen).tolist()

In [27]:
# Store both compound lists in one .npz archive under the keys 'seen' and 'unseen'
np.savez('/home/icb/alessandro.palma/imCPA/data/metadata_processed/seen_unseen_compounds.npz', seen = seen, unseen = unseen) 

In [23]:
# Split the pooled data into training, test, validation and out-of-distribution sets.
# The global NumPy seed is fixed first; presumably this is what makes the shuffling inside
# split_train_val_test repeatable (DataFrame.sample is called there without a random_state).
np.random.seed(42)
def split_train_val_test(df, seen, unseen):
    """
    Split the observations of `df` into training, test, validation and out-of-distribution sets.

    The rows of every seen compound are shuffled and divided between the training, test and
    validation sets; all rows of every unseen compound go to the out-of-distribution set.

    Params:
    -------------
        :df: data frame with one observation per row and a CPD_NAME column
        :seen: iterable of compound names to distribute over train / test / validation
        :unseen: iterable of compound names to keep apart as out-of-distribution data

    Returns:
    -------------
        (train_set, test_set, valid_set, ood_set): four data frames with the columns (and the
        original index labels) of `df`; a set that receives no rows is an empty frame that
        still carries the columns of `df`
    """
    # Positions of the rows of every compound, computed in one pass over the column
    rows_of = df.groupby('CPD_NAME', sort=False).indices
    no_rows = np.array([], dtype=np.intp)

    def _rows_for(mol):
        # Rows of one compound in their original order (an empty frame for an unknown compound)
        return df.iloc[np.sort(rows_of.get(mol, no_rows))]

    def _stack(parts):
        # Concatenate once at the end; an empty group still has to expose the columns of `df`
        return pd.concat(parts) if parts else df.iloc[0:0]

    train_parts, test_parts, valid_parts = [], [], []
    for seen_mol in seen:
        ds_slice = _rows_for(seen_mol)
        n_slice = len(ds_slice)
        n_train, n_test = int(np.round(n_slice*0.70)), int(np.round(n_slice*0.20))
        # Shuffle the observations of the compound (uses the global NumPy random state)
        shuffled = ds_slice.sample(frac=1)
        # NOTE(review): the `+ 1` offsets give the training part one row more than the rounded
        # 70% (taken from the validation remainder). Kept unchanged so that splits already
        # written to disk are reproduced - confirm whether this is intended.
        train_parts.append(shuffled.iloc[0:n_train+1])
        test_parts.append(shuffled.iloc[(n_train+1):(n_train+n_test+1)])
        valid_parts.append(shuffled.iloc[(n_train+n_test+1):])

    ood_parts = [_rows_for(unseen_mol) for unseen_mol in unseen]

    return _stack(train_parts), _stack(test_parts), _stack(valid_parts), _stack(ood_parts)

# Build the four sets from the pooled dataset and the seen / unseen compound lists
train_set, test_set, valid_set, ood_set = split_train_val_test(dataset, seen, unseen)

In [37]:
# Write the new splits to disk.
# NOTE(review): to_csv is called without index=False, so the row index is written as an extra
# unnamed first column that comes back as a data column when the files are read again.
train_set.to_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-train.csv')
test_set.to_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-test.csv')
valid_set.to_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-val.csv')
ood_set.to_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-ood.csv')

## Create training, test and validation npz files with labels

In [38]:
# Read the processed split tables back from disk (needed when they are no longer in memory).
# These assignments rebind file_train / file_test / file_val to the new splits.
file_train = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-train.csv')
file_test = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-test.csv')
file_val = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-val.csv')
file_ood = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata_processed/datasplit-ood.csv')

In [39]:
# Sanity check: sample keys shared by the new training and validation tables (expected: none)
np.intersect1d(file_train.SAMPLE_KEY.values, file_val.SAMPLE_KEY.values)

array([], dtype=object)

### Assay labels

In [40]:
# Read the sparse label matrix (converted to CSR so that single rows can be indexed)
label_mat = mmread('/home/icb/alessandro.palma/imCPA/data/metadata/label_matrix/label-matrix.mtx').tocsr()
# Index tables shipped alongside the matrix; presumably they name its columns (assays) and
# rows (compounds) - TODO confirm, they are not used by the visible cells
col_labels = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata/label_matrix/column-assay-index.csv')
row_labels = pd.read_csv('/home/icb/alessandro.palma/imCPA/data/metadata/label_matrix/row-compound-index.csv')

In [41]:
def add_assay_label_to_samples(file,
                        label_mat):
    """
    To each observation add its vector of labels.

    Params:
    -------------
        :file: data frame with a ROW_NR_LABEL_MAT column holding, for every observation, the
               row number of its labels in `label_mat`
        :label_mat: sparse label matrix that supports row indexing (e.g. CSR format)

    Returns:
    -------------
        Dense array with one row per observation of `file` and one column per column of
        `label_mat`. The result is always 2-dimensional, also for a single observation.
    """
    # Row on the assay label matrix for every observation
    file_rows = np.asarray(file.ROW_NR_LABEL_MAT, dtype=np.intp)
    # Select all rows with one sparse indexing operation and densify once, instead of
    # converting the rows one by one in a Python loop
    return label_mat[file_rows].toarray()

In [42]:
# Build the assay-label array of every split from the ROW_NR_LABEL_MAT column of its table
train_assay_labels = add_assay_label_to_samples(file_train, label_mat)
test_assay_labels = add_assay_label_to_samples(file_test, label_mat)
val_assay_labels = add_assay_label_to_samples(file_val, label_mat)
ood_assay_labels = add_assay_label_to_samples(file_ood, label_mat)

In [43]:
# Number of entries along the first axis of the training label array
len(train_assay_labels)

204852

In [44]:
# Save the label array of every split as an .npz archive under the key 'assay_labs'
np.savez('/home/icb/alessandro.palma/imCPA/data/metadata_processed/labels_train.npz', assay_labs = train_assay_labels)
np.savez('/home/icb/alessandro.palma/imCPA/data/metadata_processed/labels_test.npz', assay_labs = test_assay_labels)
np.savez('/home/icb/alessandro.palma/imCPA/data/metadata_processed/labels_valid.npz', assay_labs = val_assay_labels)
np.savez('/home/icb/alessandro.palma/imCPA/data/metadata_processed/labels_ood.npz', assay_labs = ood_assay_labels)

## Trial: dataset classes for the image data

In [None]:
# Directory with the split data; presumably the one that holds the `<fold>_data_index.npz`
# files read by CellPaintingDataset - TODO confirm
data_path = '/home/icb/alessandro.palma/data/splits'

In [None]:
class CellPaintingDataset:
    """
    Dataset class for image data.

    Reads the index files of the train, val, test and ood folds from `data_path` and builds one
    CellPaintingFold per fold.

    Attributes set by the constructor:
        fold_datasets: dict fold name -> dict with the raw arrays read from the index file
                       (keys 'file_names', 'mol_names', 'mol_smiles', 'assay_labels')
        folds: dict fold name -> CellPaintingFold built from those arrays
    """
    def __init__(self, data_path, transform, device='cuda', return_labels=False, n_expected_compounds=10600):
        """
        Params:
        -------------
            :data_path: the repository where the data is stored; it must contain one
                        `<fold>_data_index.npz` file for each of train, val, test and ood
            :transform: a pytorch transform object to apply augmentation to the data 
            :device: stored on the instance; not used by the code of this class
            :return_labels: bool to assess whether to return labels together with observations in __getitem__
            :n_expected_compounds: total number of distinct compounds expected over the training
                                   and ood folds; pass None to skip the check
        """    
        assert os.path.exists(data_path), 'The data path does not exist'

        # Read train, validation and test sets 
        self.data_path = data_path 
        self.transform = transform 
        self.device = device 
        self.return_labels = return_labels

        # Read the drug names
        print('Load the data')
        self.fold_datasets = self.read_folds()
        
        # Take the seen molecules from the training set and the unseen ones from the ood set
        seen_compounds = sorted(np.unique(self.fold_datasets['train']['mol_names']))
        unseen_compounds = np.unique(self.fold_datasets['ood']['mol_names'])
        if n_expected_compounds is not None:
            n_compounds = len(seen_compounds) + len(unseen_compounds)
            assert n_compounds == n_expected_compounds, \
                f'Expected {n_expected_compounds} compounds, found {n_compounds}'
        
        # Map every seen molecule to an integer label following the sorted order
        mol2label = {mol: label for label, mol in enumerate(seen_compounds)}
        
        # Onehot encoder over the seen molecules. handle_unknown='ignore' encodes a molecule that
        # is not among the categories as an all-zero row; with the default setting, transforming
        # the molecules of the ood fold would raise.
        # NOTE(review): OneHotEncoder (sklearn.preprocessing) is not imported in the visible part
        # of the file, and the `sparse` argument was renamed `sparse_output` in scikit-learn 1.2.
        encoder_drug = OneHotEncoder(sparse=False, categories=[seen_compounds], handle_unknown='ignore')
        encoder_drug.fit(np.array(seen_compounds).reshape((-1,1)))
        
        # Initialize the datasets and keep them on the instance
        self.folds = {}
        for fold_name in ('train', 'val', 'test', 'ood'):
            fold = CellPaintingFold(fold_name, self.fold_datasets[fold_name], encoder_drug, mol2label,
                                    self.transform, self.return_labels)
            # The fold loads its images from below the data path
            fold.data_path = self.data_path
            self.folds[fold_name] = fold
        
        
    def read_folds(self):
        """
        Read the index file of the train, val, test and ood folds from the data path.

        Returns a dict fold name -> dict with the keys 'file_names', 'mol_names', 'mol_smiles'
        and 'assay_labels'.
        """
        datasets = dict()
        for fold_name in ['train', 'val', 'test', 'ood']:
            # Every fold has its own index file below the data path
            data_index_path = os.path.join(self.data_path, f'{fold_name}_data_index.npz')
            fold_file, mol_names, mol_smiles, assay_labels = self.get_files_and_mols_from_path(data_index_path=data_index_path)

            # Add the important entries to the dataset
            datasets[fold_name] = {'file_names': fold_file,
                                   'mol_names': mol_names,
                                   'mol_smiles': mol_smiles,
                                   'assay_labels': assay_labels}
        return datasets
    
    def get_files_and_mols_from_path(self, data_index_path): 
        """
        Load object with image names, molecule names and smiles 
        -------------------
        data_index_path: The path to the data index with information about molecules and sample names

        Returns the arrays stored under the keys 'filenames', 'mol_names', 'mol_smiles' and
        'assay_labels' of the index file, in that order.
        """
        assert os.path.exists(data_index_path), 'The data index file does not exist'
        # Load the index file and close it again once the arrays have been read.
        # NOTE(review): allow_pickle=True unpickles object arrays, which can execute arbitrary
        # code - only load index files from a trusted source.
        with np.load(data_index_path, allow_pickle= True) as file:
            file_names = file['filenames']
            mol_names = file['mol_names']
            mol_smiles = file['mol_smiles']
            assay_labels = file['assay_labels']
        return file_names, mol_names, mol_smiles, assay_labels
    

class CellPaintingFold(Dataset):
    """
    Torch dataset that serves the images, and optionally the labels, of one fold.
    """
    def __init__(self, fold, data, drug_encoder, mol2label, transform, return_labels = True, data_path=None):
        """
        Params:
        -------------
            :fold: name of the fold (stored on the instance, not used otherwise)
            :data: dict with the keys 'file_names', 'mol_names', 'mol_smiles' and 'assay_labels'
            :drug_encoder: fitted encoder whose `transform` maps an (n, 1) array of molecule
                           names to one row of encoding per molecule
            :mol2label: dict from molecule name to integer label
            :transform: callable applied to every loaded image tensor, or None
            :return_labels: whether __getitem__ returns the labels together with the image
            :data_path: root directory of the image files; it can also be assigned to the
                        `data_path` attribute after construction
        """
        super().__init__() 
        
        self.fold = fold
        self.data_path = data_path
        
        # For each piece of the data create its own object
        self.file_names = data['file_names']
        self.mol_names = data['mol_names']
        self.mol_smiles = data['mol_smiles']
        self.assay_labels = data['assay_labels']
        
        self.drug_encoder = drug_encoder
        self.mol2label = mol2label
        
        self.transform = transform
        self.return_labels = return_labels 
    
        # One-hot encode molecules (the encoder receives a single-column 2D array)
        self.one_hot_drugs = self.drug_encoder.transform(np.asarray(self.mol_names).reshape((-1,1)))
        
        
    def __len__(self):
        """
        Total number of samples 
        """
        return len(self.file_names)
    
    
    def _load_image(self, idx):
        """
        Load the image of observation `idx`; returns (file name, image tensor).
        """
        if self.data_path is None:
            raise ValueError('data_path must be set before images can be loaded')
        img_file = self.file_names[idx]
        # The first two '-'-separated fields of the file name are the sub-directories of the image
        name_parts = img_file.split('-')
        sample, well = name_parts[0], name_parts[1]
        # Load image 
        with np.load(os.path.join(self.data_path, sample, well, f'{img_file}.npz'), allow_pickle = True) as f:
            img = f['arr_0']
        img = torch.from_numpy(img).to(torch.float)
        img = img.permute(2,0,1)  # Place channel dimension in front of the others 
        if self.transform is not None:
            img = self.transform(img)
        return img_file, img
    
    
    def _labelled_item(self, idx):
        """
        Image of observation `idx` together with all of its labels, as a dict.
        """
        img_file, img = self._load_image(idx)
        return dict(X=img, 
                    file_name=img_file,
                    mol_name=self.mol_names[idx], 
                    mol_one_hot=self.one_hot_drugs[idx],
                    mol_smile=self.mol_smiles[idx],
                    assay_labels=self.assay_labels[idx])
    
    
    def __getitem__(self, idx):
        """
        Generate one example datapoint: a dict with the image under 'X' and, when
        `return_labels` is set, the file name, molecule name, one-hot encoding, smiles
        and assay labels of the observation.
        """
        if self.return_labels:
            return self._labelled_item(idx)
        _, img = self._load_image(idx)
        return dict(X=img)

    def sample(self, n, seed=42):
        """
        Sample `n` random observations of the fold without replacement.

        Returns a dict with the stacked images under 'X' and lists under 'file_name',
        'mols_one_hot', 'mol_smile' and 'assay_label'. Labels are returned regardless of
        `return_labels`.
        """
        # NOTE: this reseeds the global NumPy random state
        np.random.seed(seed)
        # Pick indices
        idx = np.arange(len(self.file_names))
        idx_sample = np.random.choice(idx, n, replace=False)

        subset_mol_one_hot = []
        subset_mol_smiles = []
        subset_assay_labels = []
        subset_file_names = []
        imgs = []
        
        for i in idx_sample:
            # Read the entries by key so that the order of the item dict does not matter
            item = self._labelled_item(i)
            imgs.append(item['X'].unsqueeze(0))
            subset_mol_one_hot.append(item['mol_one_hot'])
            subset_mol_smiles.append(item['mol_smile'])
            subset_assay_labels.append(item['assay_labels'])
            subset_file_names.append(item['file_name'])
        
        imgs = torch.cat(imgs, dim=0)
        return dict(X=imgs, 
                    file_name=subset_file_names,
                    mols_one_hot=subset_mol_one_hot, 
                    mol_smile=subset_mol_smiles,
                    assay_label=subset_assay_labels
                ) 
        