Import the packages and libraries used by this notebook

In [None]:
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd 
import torch
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torch import nn
import torch.optim as optim
import os
from dataloader import *
from model import *
from torchsummary import summary
from loss_functions import *
from utils import reconstruction_compare, create_loss_graph, dataframe_w_latentvecs, create_img_scatterplot, create_model_view_img_scatterplot
from utils import create_annotated_scatterplot, get_latent_vectors, get_growth_medium, get_plate_id, zoom_img_scatterplot
import hashlib
from sklearn.manifold import TSNE
from umap import UMAP
import re

Initialize the "arguments" dictionary, which holds the hyperparameters, file paths, and other settings used throughout the codebase.

In [None]:
arguments = {
'DEVICE' : torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu"),

'CANDESCENCE' : "/home/data/refined/candescence",
'TLV' : "/home/data/refined/candescence/tlv",
'VAES' : "/home/data/refined/candescence/tlv/vaes",
'CUT_IMAGES' : os.path.join("/home/data/refined/candescence/tlv", "0.2-images_cut", "all_simple_statmatch"),
'SAVED_MODELS' : "/home/data/refined/candescence/tlv/saved_torch_models",
'METADATA_DIR' : os.path.join("/home/data/refined/candescence/tlv", "data_files", "Calb_Master_10062022.tsv"),
'GRAPH_SAVES' : "/home/data/refined/candescence/tlv/torch_model_graphs",
'MASTER_SAVES' : "/home/data/refined/candescence/tlv/saved_MASTERs",
'LOSS_SAVES' : "/home/data/refined/candescence/tlv/saved_loss_graphs",

'exp' : "testing_simple_statmatch_2",
'dataset_seed' : 9954,
'training_seed' : 9954,
'image_type' : 'non-wash', # either "wash" or "non-wash"; can only specify this if taking images from all-final or a directory containing both wash and non-wash colonies
'train_num' : 6400,
'val_num' : 1600,
'test_num' : 3200,

'batch_size' : 64,
'kernel_size' : 3,
'latent_dim' : 6,
'intermediate_dim' : 60,
'epochs' : 20,
'learning_rate' : 0.000125,
'weight_decay' : 1.5e-5,
'kl_weight' : 0.4,
'MSE_weight' : 2.0,
'size_bins' : 5,
'intensity_bins' : 10,
'OH_in_decoder' : False
}

print(f"Device in use: {arguments['DEVICE']}")

Here we wipe any previous instance of the log file so that only the current run is recorded

In [None]:
# Truncate (or create) the log file so it only contains records from this run.
# Note: all logging information is stored in dataloading.log, NOT just the logging information from the dataloader
with open('dataloading.log', 'w'):
    pass

Next, we create the dataloaders for the training, validation, and test datasets - to be used later during training.

In [None]:
# create_dataloader builds the (train, validation, test) loaders from the settings in `arguments`.
(train_dataloader,
 val_dataloader,
 test_dataloader) = create_dataloader(arguments)

# Number of batches in the training loader
print(len(train_dataloader))

Here we instantiate a VAE object using the `VAE` class imported from `model.py`. We also define the Adam optimizer and move the model to the GPU/device used for training.

In [None]:
torch.manual_seed(arguments['training_seed'])

# Peek at one training batch to read the one-hot encoding length. Each batch unpacks into
# (images, one_hot); the length is taken from dim 2 of the one-hot tensor and passed to VAE.__init__
# so the architecture adapts automatically if the one-hot encoding changes length.
# NOTE: this draw happens after seeding and before model construction; if the loader shuffles, it
# is part of the seeded RNG sequence, so keep this order to stay reproducible with earlier runs.
sample_img, sample_OH = next(iter(train_dataloader))
OH_len = sample_OH.shape[2]

vae = VAE(arguments, OH_len, OH_in_decoder=arguments['OH_in_decoder'])

# Move the model to the target device *before* constructing the optimizer, as the torch.optim
# documentation recommends, so the optimizer is built over the device-resident parameters.
vae.to(arguments['DEVICE'])

optimizer = torch.optim.Adam(vae.parameters(), lr=arguments['learning_rate'], weight_decay=arguments['weight_decay'])

Now we define the functions that run a single training epoch and a single validation epoch. Each returns that epoch's total loss divided by the number of samples in its dataset.

In [None]:
def training_epoch(vae, device, dataloader, optimizer, arguments):
    """Run one optimisation pass over `dataloader`.

    Args:
        vae: model called as ``vae(images, one_hot)``; after the call, the loss reads
            ``vae.encoder.sigma`` and ``vae.encoder.mu``.
        device: torch device each batch is moved to.
        dataloader: yields ``(images, one_hot)`` batches.
        optimizer: optimizer over ``vae``'s parameters; stepped once per batch.
        arguments: config dict forwarded to ``get_MSE_kl_loss``.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
        NOTE(review): this is a per-sample loss only if get_MSE_kl_loss returns a batch *sum*;
        if it returns a batch mean, the value is additionally scaled by 1/batch_size — confirm.
    """
    vae.train()  # training mode for both encoder and decoder
    running_loss = 0.0

    for images, one_hot in dataloader:
        images, one_hot = images.to(device), one_hot.to(device)
        reconstruction = vae(images, one_hot)

        batch_loss = get_MSE_kl_loss(images, reconstruction, vae.encoder.sigma, vae.encoder.mu, arguments)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader.dataset)


def validation_epoch(vae, device, dataloader, arguments):
    """Evaluate `vae` on `dataloader` without updating weights.

    Args:
        vae: model called as ``vae(images, one_hot)``; after the call, the loss reads
            ``vae.encoder.sigma`` and ``vae.encoder.mu`` (same contract as training_epoch).
        device: torch device each batch is moved to.
        dataloader: yields ``(images, one_hot)`` batches.
        arguments: config dict forwarded to ``get_MSE_kl_loss``.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``. The normalisation
        matches training_epoch, so the two curves are comparable.
    """
    vae.eval()  # evaluation mode
    val_loss = 0.0
    with torch.no_grad():  # no gradients needed for validation
        for x, one_hot in dataloader:
            x = x.to(device)
            one_hot = one_hot.to(device)

            # One forward pass is enough: as in training_epoch, the loss reads mu/sigma from
            # vae.encoder after vae(...) runs. The previous extra vae.encoder(...) call was unused.
            x_hat = vae(x, one_hot)
            loss = get_MSE_kl_loss(x, x_hat, vae.encoder.sigma, vae.encoder.mu, arguments)

            val_loss += loss.item()

    return val_loss / len(dataloader.dataset)
            

Now we run the training/validation loop, across the number of epochs specified above in the arguments/parameters

In [None]:
train_losses = []
val_losses = []

for epoch in range(arguments['epochs']):
    train_loss = training_epoch(vae, arguments['DEVICE'], train_dataloader, optimizer, arguments)
    val_loss = validation_epoch(vae, arguments['DEVICE'], val_dataloader, arguments)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, arguments['epochs'], train_loss, val_loss))
    # Visual check of reconstruction quality after every epoch (n=10 test images).
    reconstruction_compare(arguments, test_dataloader.dataset, vae.encoder, vae.decoder, n=10)

# Create/plot the losses over the training period
create_loss_graph(train_losses, val_losses, arguments)

# Save the final weights; create the output directory first so torch.save does not fail on a fresh setup.
os.makedirs(arguments['SAVED_MODELS'], exist_ok=True)
model_save_path = os.path.join(arguments['SAVED_MODELS'], f"{arguments['exp']}.pth")
torch.save(vae.state_dict(), model_save_path)

In the next cell, we use the utils functions to visualize the model's latent space. t-SNE and UMAP each reduce the latent vectors to two dimensions, which can then easily be displayed on a scatterplot. The resulting table (MASTER) is saved as a CSV for later analysis.

In [None]:
############### PUT MODEL IN EVALUATION/INFERENCE MODE: ######################
vae.eval()
###############                                         ######################

# Reload the held-out test split saved under VAES; the file unpacks into (test_dataset, test_indices).
# `.dataset.dataset` unwraps two nested dataset wrappers (presumably Subset objects) to reach the full dataset.
# NOTE(review): this file holds pickled Python objects, not just tensors. On torch >= 2.6 the default
# torch.load(weights_only=True) refuses such files; pass weights_only=False there (trusted local files only).
test_split_path = os.path.join(arguments['VAES'], f'{arguments["dataset_seed"]}_test')
test_dataset, test_indices = torch.load(test_split_path)
full_dataset = test_dataset.dataset.dataset

MASTER = dataframe_w_latentvecs(arguments, full_dataset, test_indices, vae)

# Latent-vector columns are those whose names start with "V" followed by digits.
# (Raw string: '\d' in a plain literal is an invalid escape and triggers a warning.)
features = MASTER.filter(regex=r'^V\d+')

### Perform t-SNE dimensionality reduction on latent space ###
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(features)
MASTER['tsne-2d-one'] = tsne_results[:, 0]
MASTER['tsne-2d-two'] = tsne_results[:, 1]

### Perform UMAP dimensionality reduction on latent space ###
umap = UMAP(n_components=2, random_state=42, min_dist=0.2, n_neighbors=15)
umap_results = umap.fit_transform(features)
MASTER['umap-2d-one'] = umap_results[:, 0]
MASTER['umap-2d-two'] = umap_results[:, 1]

# Save the MASTER dataframe for later use; create the directory first so to_csv cannot fail on a fresh setup.
os.makedirs(arguments['MASTER_SAVES'], exist_ok=True)
file_name = arguments['exp'] + "_MASTER.csv"
MASTER.to_csv(os.path.join(arguments['MASTER_SAVES'], file_name), index=False)


Here we load in the saved MASTER file for the experiment we want to analyze, and generate scatterplots representing the latent space with the original encoded images instead of points. Both t-SNE and UMAP can be used for dimensionality reduction.

In [None]:
# View the image-overlay scatterplots for the experiment named below.
# Reduction technique is given as 'tsne' (t-SNE) or 'umap' (UMAP); both are drawn.
exp_for_analysis = arguments['exp']
file_name = f"{exp_for_analysis}_MASTER.csv"
MASTER = pd.read_csv(os.path.join(arguments['MASTER_SAVES'], file_name))

for technique in ('tsne', 'umap'):
    create_img_scatterplot(MASTER, arguments, reduction_technique=technique)

# Alternative view (needs test_dataloader from the training section):
# create_model_view_img_scatterplot(MASTER, test_dataloader.dataset, arguments)

In this next code block, we load the saved MASTER file we want (by giving the experiment name). We can then create annotated/colored scatterplots based on any category of information that exists within the MASTER/METADATA (e.g. medium, geographical origin, strain, etc.).

In [None]:
### LIST OF ANNOTATION CATEGORIES (any MASTER/metadata column can be used):
# 'Plate'
# 'medium'
# 'day'
# 'Broad.Niche'
# 'Geography.General'
# many more - see the file at arguments['METADATA_DIR'] for all categories
### END LIST
# NOTE: Many columns contain 'nan', 'NaN', or 'unknown' values - ideally filter these out or display them as one color

annotation_category = 'medium'

exp_for_analysis = arguments['exp']
file_name = f"{exp_for_analysis}_MASTER.csv"
MASTER = pd.read_csv(os.path.join(arguments['MASTER_SAVES'], file_name))

# Optionally keep color for only one plate/day/medium combination by relabelling every other row as plate 0:
# conditions = (MASTER['Plate'] != 14) | (MASTER['day'] != 5) | (MASTER['medium'] != 'serum')
# MASTER.loc[conditions, 'Plate'] = 0

create_annotated_scatterplot(MASTER, annotation_category)


We can call on the zoom_img_scatterplot() function to zoom in on a region of interest in either the UMAP or t-SNE scatterplots. The size of the zoomed colonies can be modified in the code for the function.

In [None]:
# Region of interest in the chosen embedding's coordinates, ordered (xmin, xmax, ymin, ymax).
zoom_window = (10, 42, -25, -10)

# The optional `annotate=` argument picks an annotation style for the zoomed box; its default is None.
zoom_img_scatterplot(MASTER, arguments, 'tsne', zoom_window, annotate=None)

Next come a series of functions and function calls for taking a closer look at specific colonies or plates of interest, plotting only those colonies.

In [None]:
### NOTE: Right now you need to run the entire model successfully once before evaluating this cell. TODO: Add
# ability to load in the state of a previous model into a variable called "inference_model" or something, and then
# pass that model to these functions. As of now that would still require running the first few cells though (NOT the
# whole training loop though)
### 


### analyze_specific_colonies() takes a select set of colonies and plots a t-SNE/UMAP latent spread of only those colonies,
### plus a reconstruction comparison of some of them (via specific_reconstruction_compare())


def create_and_save_full_master(arguments, full_dataset):
    """Build a metadata table for every image in `full_dataset` and save it as FULL_MASTER.csv.

    No latent vectors are computed, which saves computation and time. For each image, the plate
    number, medium, day, replicate and well position are parsed from its filename. The table is
    inner-joined with the metadata TSV on (Plate, Position). Two columns are then added:
      * unique_ID  - first 8 hex characters of the SHA-256 of the filename;
      * plate_name - the filename stem with its trailing "-r<row>-c<col>" well suffix removed.

    Args:
        arguments: config dict; reads 'METADATA_DIR' (tab-separated metadata file) and
            'MASTER_SAVES' (output directory, created if missing).
        full_dataset: object supporting ``len()`` and ``get_image_filename(list_of_indices)``.

    Returns:
        The FULL_MASTER DataFrame that was written to disk. Existing callers that ignore the
        return value and re-read the CSV keep working.
    """
    meta_table = pd.read_csv(arguments['METADATA_DIR'], sep='\t', encoding='ISO-8859-1')

    filenames = full_dataset.get_image_filename(list(range(len(full_dataset))))
    df = pd.DataFrame({'filenames': filenames})
    names = df['filenames']

    # Plate number from "Pl<n>" or "P<n>"; "Pwt" and unmatched names become -1.
    plate_numbers = names.str.extract(r'Pl(\d+)|P(\d+)|Pwt', expand=False)
    df['Plate'] = plate_numbers[0].fillna(plate_numbers[1]).fillna(-1).astype(int)
    # Growth medium, with abbreviations normalised to their full names.
    df['medium'] = names.str.extract(r'(spider|ctrl|spdr|control|serum|RPMI|YPD)', expand=False)
    df['medium'] = df['medium'].replace({'spdr': 'spider', 'ctrl': 'control'})
    # Day; -1 when absent, matching how Plate and replicate treat unparseable names.
    df['day'] = names.str.extract(r'day(\d+)', expand=False).fillna('-1').astype(int)
    # Replicate: the digits between "_" and "-" (e.g. "_1-" in "..._day5_1-r1-c5.png").
    df['replicate'] = names.str.extract(r'_(\d+)-', expand=False).fillna('-1').astype(int)
    # Raw position text: from the first "-" up to the last "." (greedy), normalised below.
    df['Position'] = names.str.extract(r'-(.*)\.', expand=False)

    def get_row_col(pos):
        """Convert "r<row>-c<col>" into well notation (r1-c5 -> "A5"); return anything else unchanged."""
        if not isinstance(pos, str):  # NaN when the filename has no "-...." segment
            return pos
        m = re.search(r"r(\d+)-c(\d+)", pos)
        if m is None:
            return pos
        return chr(int(m.group(1)) + 64) + m.group(2)  # row 1 -> 'A' (ASCII 65)

    df['Position'] = df['Position'].apply(get_row_col)

    full_MASTER = pd.merge(df, meta_table, on=["Plate", "Position"], how='inner')

    # Unique ID of each colony, and the name of the plate (of 96 colonies) it belongs to.
    def generate_unique_id(filename):
        """First 8 hex characters of the SHA-256 digest of `filename`."""
        return hashlib.sha256(filename.encode()).hexdigest()[:8]

    def extract_plate_name(full_path):
        """Filename stem without its trailing "-r<row>-c<col>" well suffix.

        A regex is used instead of chopping a fixed 6/7 characters so that any row/column digit
        count is handled (e.g. rows >= 10).
        """
        stem = os.path.splitext(os.path.basename(full_path))[0]
        return re.sub(r'-r\d+-c\d+$', '', stem)

    full_MASTER['unique_ID'] = full_MASTER['filenames'].apply(generate_unique_id)
    full_MASTER['plate_name'] = full_MASTER['filenames'].apply(extract_plate_name)

    os.makedirs(arguments['MASTER_SAVES'], exist_ok=True)
    file_name = "FULL_MASTER.csv"
    full_MASTER.to_csv(os.path.join(arguments['MASTER_SAVES'], file_name), index=False)
    return full_MASTER

# Build FULL_MASTER on disk, then load it back.
# unique_ID and plate_name are forced to str: read_csv infers dtypes, so an 8-char hex ID made only
# of digits (e.g. "12345678") could otherwise come back as a number and break ID lookups.
create_and_save_full_master(arguments, full_dataset)
FULL_MASTER = pd.read_csv(os.path.join(arguments['MASTER_SAVES'], "FULL_MASTER.csv"),
                          dtype={'unique_ID': str, 'plate_name': str})
# Map each plate (of 96 colonies) to the list of its colony IDs
plate_to_colony_mapping = FULL_MASTER.groupby('plate_name')['unique_ID'].apply(list).to_dict()

def get_col_ids_from_plates(plate_list, mapping=None):
    """Return the colony unique_IDs of every plate in `plate_list`, concatenated in plate order.

    Args:
        plate_list: iterable of plate names (values of FULL_MASTER's 'plate_name' column).
        mapping: optional dict plate_name -> list of unique_IDs. Defaults to the module-level
            `plate_to_colony_mapping`; passing it explicitly avoids the hidden global dependency.

    Returns:
        A new list of IDs. Plate names missing from the mapping contribute nothing.
    """
    if mapping is None:
        mapping = plate_to_colony_mapping
    return [colony_id for plate in plate_list for colony_id in mapping.get(plate, [])]

# Plates (by plate_name) whose colonies we want to inspect in detail.
plates_of_interest = [
    "P1_control_day5_1",
    "P17_YPD_day2_2",
    "Pl10_RPMI_day5_2",
]

# All colony IDs belonging to those plates
colonies_to_analyze = get_col_ids_from_plates(plates_of_interest)

# Used in analyze_specific_colonies() to reconstruct specific colonies
def specific_reconstruction_compare(arguments, reconstruction_tuple, encoder, decoder, ids_to_reconstruct):
    """Plot original images (top row) above their VAE reconstructions (bottom row).

    Args:
        arguments: config dict; 'DEVICE' is read.
        reconstruction_tuple: pair (samples, unique_ids). samples[k] is a dataset item whose [0] is
            the image tensor and [1] the one-hot tensor; unique_ids[k] is that item's colony ID.
        encoder, decoder: VAE halves; the reconstruction is decoder(encoder(img, OH), OH).
        ids_to_reconstruct: colony IDs to show, or None to show up to 10 random samples.

    Raises:
        ValueError: if an ID in `ids_to_reconstruct` is not in reconstruction_tuple[1].
    """
    samples = reconstruction_tuple[0]

    if ids_to_reconstruct is None:
        # Random colonies without replacement; capped so selections with < 10 colonies don't crash.
        n_random = min(10, len(samples))
        indices = np.random.choice(len(samples), size=n_random, replace=False)
    else:
        id_list = reconstruction_tuple[1]
        indices = [id_list.index(colony_id) for colony_id in ids_to_reconstruct]

    n_cols = len(indices)
    encoder.eval()
    decoder.eval()

    plt.figure(figsize=(16, 4.5))
    for col, idx in enumerate(indices):
        img = samples[idx][0].unsqueeze(0).to(arguments['DEVICE'])  # add a batch dimension
        OH_tensor = samples[idx][1].to(arguments['DEVICE'])
        # NOTE(review): only the image gets a batch dimension; the one-hot tensor is passed as stored
        # in the dataset — confirm this matches what encoder/decoder expect.
        with torch.no_grad():
            rec_img = decoder(encoder(img, OH_tensor), OH_tensor)

        # Top row: original
        ax = plt.subplot(2, n_cols, col + 1)
        plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if col == n_cols // 2:
            ax.set_title('Original images')

        # Bottom row: reconstruction
        ax = plt.subplot(2, n_cols, col + 1 + n_cols)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if col == n_cols // 2:
            ax.set_title('Reconstructed images')

    plt.show()


# Takes only the colonies of interest, builds a mini-master with their latent vectors and
# t-SNE/UMAP coordinates, and runs the visual analyses on it.
def analyze_specific_colonies(full_master, full_dataset, vae, col_id_list, reduction_technique, anno_type, ids_to_reconstruct, arguments):
    """Embed and visualise only the colonies whose unique_ID is in `col_id_list`.

    Args:
        full_master: FULL_MASTER-style DataFrame with 'filenames' and 'unique_ID' columns.
        full_dataset: dataset supporting len(), indexing, and get_image_filename(indices).
        vae: trained model; vae.encoder/vae.decoder are used.
        col_id_list: colony IDs to keep.
        reduction_technique: passed to create_img_scatterplot ('tsne' or 'umap').
        anno_type: column used to colour the annotated scatterplot.
        ids_to_reconstruct: specific colony IDs for the reconstruction comparison.
        arguments: config dict forwarded to the helpers.

    Returns:
        The filtered master with latent-vector columns plus 'tsne-2d-*' and 'umap-2d-*' columns.
    """
    all_filenames = full_dataset.get_image_filename(list(range(len(full_dataset))))

    filtered_master = full_master[full_master['unique_ID'].isin(col_id_list)]

    # Dataset indices of the selected colonies (set membership keeps the scan linear).
    wanted_filenames = set(filtered_master['filenames'])
    indices_of_interest = [i for i, name in enumerate(all_filenames) if name in wanted_filenames]

    # Attach latent vectors; they are keyed by 'file_name', so the duplicate 'filenames' column is dropped.
    df_with_latentvecs = get_latent_vectors(dataset=full_dataset, indices=indices_of_interest, encoder=vae.encoder, arguments=arguments)
    filtered_master = (
        filtered_master
        .merge(df_with_latentvecs, left_on='filenames', right_on='file_name', how='left')
        .drop(columns=['filenames'])
    )

    # (dataset items, colony IDs), aligned by position, for specific_reconstruction_compare.
    tuples_of_interest = [full_dataset[i] for i in indices_of_interest]
    filename_to_id = filtered_master.set_index('file_name')['unique_ID'].to_dict()
    unique_ids = [filename_to_id[all_filenames[i]] for i in indices_of_interest]
    reconstruct_tuple = (tuples_of_interest, unique_ids)

    # Latent-vector columns ("V" followed by digits); raw string avoids an invalid-escape warning.
    features = filtered_master.filter(regex=r'^V\d+')

    ### Perform t-SNE dimensionality reduction on latent space ###
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(features)
    filtered_master['tsne-2d-one'] = tsne_results[:, 0]
    filtered_master['tsne-2d-two'] = tsne_results[:, 1]

    ### Perform UMAP dimensionality reduction on latent space ###
    umap = UMAP(n_components=2, random_state=42, min_dist=0.2, n_neighbors=15)
    umap_results = umap.fit_transform(features)
    filtered_master['umap-2d-one'] = umap_results[:, 0]
    filtered_master['umap-2d-two'] = umap_results[:, 1]

    # Image scatter and annotated scatter using the existing utils functions
    create_img_scatterplot(filtered_master, arguments, reduction_technique)
    create_annotated_scatterplot(filtered_master, anno_type)

    specific_reconstruction_compare(arguments, reconstruct_tuple, vae.encoder, vae.decoder, ids_to_reconstruct=ids_to_reconstruct)
    specific_reconstruction_compare(arguments, reconstruct_tuple, vae.encoder, vae.decoder, ids_to_reconstruct=None)  # 10 random colonies

    return filtered_master

analyzed_colonies = analyze_specific_colonies(
    full_master=FULL_MASTER,
    full_dataset=full_dataset,
    vae=vae,
    col_id_list=colonies_to_analyze,
    reduction_technique='tsne',
    anno_type="medium",
    ids_to_reconstruct=["36f3a141", "1c89d1a9", "9f449d55"],
    arguments=arguments,
)
