In [1]:
import sys
import os

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.insert(0, parent_dir)

import pandas as pd 
import numpy as np 
import re
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split
from directories import *
from VAE_models.VAE_model import *
from VAE_models.VAE_model_enhanced import *
from VAE_models.VAE_model_2 import *
from VAE_models.VAE_model_single import *
from training import *
from extras import *
from sklearn.decomposition import PCA
from collections import defaultdict

# Apply the 'ggplot' matplotlib style to every figure drawn below.
plt.style.use('ggplot')
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
print(f"Current working directory: {current_dir}")
print(f"Parent directory: {parent_dir}")

<a id='data_exploration'></a>
# 1) Data loading 

Loading essential genes and the dataset files 

In [None]:
essential_genes = pd.read_csv(PAPER_ESSENTIAL_GENES)

In [None]:
large_data = pd.read_csv(TEN_K_DATASET, index_col=[0], header=[0])

Turning all sample names uppercase for consistency

In [None]:
large_data.columns = large_data.columns.str.upper()

In [None]:
large_data

In [None]:
large_data.sum(axis=1).sort_values()

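Dropping the `Lineage` row and transposing the table so that rows are samples and columns are genes (the phylogroup metadata itself is loaded in section 2.1)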

In [None]:
# Remove the 'Lineage' row, then flip the table and materialise it as a NumPy array.
data_without_lineage = large_data.drop(index=['Lineage'])
large_data_t = np.array(data_without_lineage.T)

print(f"Full dataset shape: {large_data_t.shape}")


# 2) Data preprocessing

## 2.1) Dataset preprocessing 

In [None]:
phylogroup_data = pd.read_csv(TEN_K_DATASET_PHYLOGROUPS, index_col=[0], header=[0])

In [None]:
# Inner-join the transposed gene table with the phylogroup metadata: the left key is the
# index of the transposed table, the right key is 'ID'. Rows without a partner in the
# other table are dropped by the inner join.
# NOTE(review): later cells treat the last column of merged_df as the phylogroup label —
# confirm the column layout of phylogroup_data.
merged_df = pd.merge(data_without_lineage.transpose(), phylogroup_data, how='inner', left_index=True, right_on='ID')

In [None]:
# Every column except the last becomes the feature matrix; the last column becomes the
# per-sample label array.
# NOTE(review): this relies on the merge placing exactly one metadata column last — confirm.
data_array_t = np.array(merged_df.iloc[:, :-1])
phylogroups_array = np.array(merged_df.iloc[:, -1])

In [None]:
print("Checking dataset shapes")
print(f"Values array: {data_array_t.shape}")
print(f"Phylogroups array: {phylogroups_array.shape}")

## 2.2) Converting the dataset into splits and dataloaders

In [None]:
# Convert the feature matrix to a float32 PyTorch tensor
data_tensor = torch.tensor(data_array_t, dtype=torch.float32)

# Split into train, validation and test sets: 30% is held out first, then that held-out
# part is split again so that test_size=0.3333 of it becomes the test set and the rest
# becomes the validation set. Labels are split alongside the data with the same seed.
train_data, temp_data, train_labels, temp_labels = train_test_split(data_tensor, phylogroups_array, test_size=0.3, random_state=12345)
val_data, test_data, val_labels, test_labels = train_test_split(temp_data, temp_labels, test_size=0.3333, random_state=12345)
# Phylogroup labels of the test samples, used to colour the latent-space plots below
test_phylogroups = test_labels

# Set batch size
batch_size = 32

# Wrap each split in a TensorDataset (features only, no labels)
train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)
test_dataset = TensorDataset(test_data)

# Set loaders. shuffle=False keeps the test batches in the same order as test_phylogroups.
# NOTE(review): no loader is built for val_dataset and train_loader is not shuffled —
# confirm this is intended for this notebook.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

In [None]:
print(f"Train data shape {train_data.shape}")
print(f"Test data shape {test_data.shape}")
print(f"Val data shape {val_data.shape}")

# 3) Essential genes manipulations

Creating an array of essential genes from the paper and flattening it

In [None]:
essential_genes_array = np.array(essential_genes).flatten()

In [None]:
print(f"Total number of essential genes present in the paper: {len(essential_genes_array)}")

Creating a boolean mask over the dataset columns that marks the essential genes, so that the essential genes present in the samples can be counted efficiently

In [None]:
all_genes = merged_df.columns

In [None]:
essential_genes_mask = np.isin(all_genes, essential_genes_array)

In [None]:
print(f"Total number of essential genes present in the dataset: {np.sum(essential_genes_mask)}")

Figuring out which genes are not present in the dataset

In [None]:
# Select the entries of the paper's gene table whose value does not equal any dataset
# column name flagged by essential_genes_mask.
# NOTE(review): the boolean mask is built from the full 2-D table values; this presumably
# relies on essential_genes having a single column — confirm.
subset_not_in_essential_genes_mask = essential_genes[~np.isin(np.array(essential_genes), np.array(all_genes[essential_genes_mask]))]

Figuring out which genes are present in the dataset

In [None]:
# Select the entries of the paper's gene table whose value equals a dataset column name
# flagged by essential_genes_mask (the complement of the "not present" selection).
# NOTE(review): the boolean mask is built from the full 2-D table values; this presumably
# relies on essential_genes having a single column — confirm.
subset_in_essential_genes_mask = essential_genes[np.isin(np.array(essential_genes), np.array(all_genes[essential_genes_mask]))]

Final results

In [None]:
absent_genes = np.array(subset_not_in_essential_genes_mask).flatten()

In [None]:
print(f"Number of genes not present in the dataset: {len(absent_genes)}")

In [None]:
present_genes = np.array(subset_in_essential_genes_mask).flatten()

In [None]:
print(f"Number of genes present in the dataset: {len(present_genes)}") 

Seeing if the genes split into multiple parts in the dataset are the essential genes

In [None]:
# Find dataset columns whose name starts with an essential gene that has no exact column
# match (a gene stored as several chunk columns sharing the gene name as prefix).
# Columns already counted as present essential genes are skipped.
present_gene_names = set(present_genes)  # set lookup instead of scanning the array per column
matched_columns = []

for gene in absent_genes:
    # Escape the gene name so characters such as '.', '(' or '+' are matched literally
    # instead of being interpreted as regular-expression syntax.
    pattern = re.compile(re.escape(str(gene)))
    # pattern.match anchors at the start of the column name, so this is a prefix test.
    # NOTE(review): a prefix test also accepts longer names that merely start with the
    # same letters (the curated prefix list in the next cell excludes e.g. 'ssbA').
    matches = [col for col in merged_df.columns if pattern.match(col) and col not in present_gene_names]
    matched_columns.extend(matches)

divided_genes = np.array(matched_columns)
print(divided_genes)
print(len(divided_genes))


Manually creating the array of genes which are divided into chunks

In [None]:
divided_genes_prefixes = ['msbA', 'fabG', 'lolD', 'topA', 'metG', 'fbaA', 'higA', 'lptB', 'ssb',  'lptG', 'dnaC'] # 'higA-1', 'higA1','higA-2', 'ssbA' dont count 

In [None]:
not_present = np.array(list(set(absent_genes) - set(divided_genes_prefixes)))

In [None]:
print(f"Genes which are still not present in the dataset after prefix extraction: {not_present}")
print(f"Total number: {len(not_present)}")

Creating a new array of the genes (both single-name and divided) present in the dataset

In [None]:
combined_array = np.concatenate((present_genes, divided_genes))

In [None]:
print(f"Total umber of genes that count as essential in the dataset: {len(combined_array)}")

Creating a new gene mask including the divided essential genes 

In [None]:
essential_genes_mask = np.isin(all_genes, combined_array)

In [None]:
essential_genes_df = merged_df.loc[:, essential_genes_mask].copy()

In [None]:
essential_genes_df

In [None]:
gene_sums = essential_genes_df.sum()
zero_sum_genes = gene_sums[gene_sums == 0].index.tolist()
print(f"Genes that are not present (overall 0 in all samples): {zero_sum_genes}")

Dataframe of just absent essential genes (including the ones that are split up)

In [None]:
# For every essential gene without an exact column match, merge all essential-gene
# columns whose name starts with that gene into one presence/absence (1/0) column.
# Building the frame from a dict in one go avoids growing it column by column.
absent_essential_genes_df = pd.DataFrame(
    {
        # re.escape keeps regex metacharacters in a gene name literal; '^' anchors the
        # pattern so only column names that start with the gene name are selected.
        # A gene with no matching column yields an all-zero column.
        prefix: (essential_genes_df.filter(regex=f'^{re.escape(str(prefix))}').sum(axis=1) > 0).astype(int)
        for prefix in absent_genes
    }
)

In [None]:
absent_essential_genes_df

Dataframe of the essential genes with the columns of the genes that are divided into chunks removed

In [None]:
intermediate = essential_genes_df.drop(columns=divided_genes)

In [None]:
intermediate

Adding the absent essential genes that do occur in the dataset (as chunks) to the overall dataframe of the essential genes present in the dataset

In [None]:
row_sums = absent_essential_genes_df.sum(axis=0)
columns_to_add = absent_essential_genes_df.columns[row_sums != 0]

In [None]:
columns_to_add

In [None]:
absent_essential_genes_df[columns_to_add]

In [None]:
absent_essential_genes_df[columns_to_add].columns

Adding these selected columns to the original DataFrame

In [None]:
# Copy each selected merged presence column into `intermediate` under the gene's name.
for gene_name in columns_to_add:
    intermediate[gene_name] = absent_essential_genes_df[gene_name]

Intermediate dataframe used to plot the frequency of the essential genes present in the dataset:

In [None]:
intermediate

In [None]:
intermediate.sum(axis=0)

In [None]:
# Persist the column names of `intermediate` to a .npy file.
# NOTE(review): absolute, machine-specific path — presumably this should be built from a
# constant in `directories`; confirm.
np.save('/Users/anastasiiashcherbakova/git_projects/masters_project/data/essential_gene_in_ds.npy', intermediate.columns.to_list())

In [None]:
# Row sums of `intermediate`: one total per row (sample) across the essential-gene columns.
EG_distribution = intermediate.sum(axis=1)
# Summary statistics used for the marker lines and legend of the histogram below.
mean = np.mean(EG_distribution)
median = np.median(EG_distribution)
min_value = np.min(EG_distribution)
max_value = np.max(EG_distribution)

In [None]:
# Histogram of EG_distribution with the mean and median marked by dashed lines;
# the minimum and maximum appear in the legend only.
plt.figure(figsize=(10,8))
plt.hist(EG_distribution, color='darkorchid', bins=20)
plt.xlabel('Essential gene number')
plt.ylabel('Frequency')

mean_label = f'Mean: {mean:.2f}'
median_label = f'Median: {median:.2f}'
plt.axvline(mean, color='r', linestyle='dashed', linewidth=2, label=mean_label)
plt.axvline(median, color='b', linestyle='dashed', linewidth=2, label=median_label)

# The legend is built from stand-alone proxy lines rather than from the drawn artists.
dummy_min = plt.Line2D([], [], color='black',  linewidth=2, label=f'Min: {min_value:.2f}')
dummy_max = plt.Line2D([], [], color='black', linewidth=2, label=f'Max: {max_value:.2f}')
handles = [
    plt.Line2D([], [], color='r', linestyle='dashed', linewidth=2, label=mean_label),
    plt.Line2D([], [], color='b', linestyle='dashed', linewidth=2, label=median_label),
    dummy_min,
    dummy_max,
]
plt.legend(handles=handles)
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/EG_number.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# Row sums of `intermediate` divided by the largest row sum, so the row with the highest
# total has the value 1.
# NOTE(review): the denominator is the observed maximum, not the number of columns in
# `intermediate` — confirm this is the intended meaning of "proportion".
EG_distribution = intermediate.sum(axis=1) / intermediate.sum(axis=1).max()
# Summary statistics used for the marker lines and legend of the histogram below.
mean = np.mean(EG_distribution)
median = np.median(EG_distribution)
min_value = np.min(EG_distribution)
max_value = np.max(EG_distribution)

In [None]:
# Histogram of the normalised EG_distribution with the mean and median marked by dashed
# lines; the minimum and maximum appear in the legend only.
plt.figure(figsize=(10,8))
plt.hist(EG_distribution, color='darkorchid', bins=20)
plt.xlabel('Essential gene proportion in the dataset')
plt.ylabel('Frequency')

mean_label = f'Mean: {mean:.2f}'
median_label = f'Median: {median:.2f}'
plt.axvline(mean, color='r', linestyle='dashed', linewidth=2, label=mean_label)
plt.axvline(median, color='b', linestyle='dashed', linewidth=2, label=median_label)

# The legend is built from stand-alone proxy lines rather than from the drawn artists.
dummy_min = plt.Line2D([], [], color='black',  linewidth=2, label=f'Min: {min_value:.2f}')
dummy_max = plt.Line2D([], [], color='black', linewidth=2, label=f'Max: {max_value:.2f}')
handles = [
    plt.Line2D([], [], color='r', linestyle='dashed', linewidth=2, label=mean_label),
    plt.Line2D([], [], color='b', linestyle='dashed', linewidth=2, label=median_label),
    dummy_min,
    dummy_max,
]
plt.legend(handles=handles)
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/EG_number_proportion.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
datatset_EG = list(intermediate.columns)

In [None]:
def extract_prefix(gene):
    """Return the leading run of ASCII letters and digits of a gene name.

    The name is returned unchanged when it does not start with a letter or digit
    (this includes the empty string).
    """
    leading = re.match(r"([a-zA-Z0-9]+)", gene)
    return leading.group(1) if leading else gene

# Map every gene-name prefix to the positional indices of the columns that carry it
groups_of_gene_positions = defaultdict(list)
for column_position, column_name in enumerate(all_genes):
    groups_of_gene_positions[extract_prefix(column_name)].append(column_position)

# Switch to a plain dict (unlike a defaultdict, it does not create entries on lookup)
groups_of_gene_positions = dict(groups_of_gene_positions)

# Show the grouping for a visual check
for prefix, positions in groups_of_gene_positions.items():
    print(f"{prefix}: {positions}")

Precompute essential gene positions

In [None]:
# Restrict the prefix -> column-position map to the paper's essential genes; genes whose
# name is not a key of groups_of_gene_positions are left out.
essential_gene_positions = {
    gene: groups_of_gene_positions[gene]
    for gene in essential_genes_array
    if gene in groups_of_gene_positions
}

In [None]:
essential_gene_positions

Calculating the abundance of essential genes in the dataset 

In [None]:
# Total of each essential gene's column(s) over all rows of merged_df; genes without any
# matching column keep the initial value 0.
essential_gene_abundance = pd.Series(0, index=essential_genes_array)

column_names = merged_df.columns

for gene, positions in essential_gene_positions.items():
    # One column or several chunk columns: summing across the selected columns and then
    # over the rows gives the same grand total in both cases.
    gene_columns = [column_names[pos] for pos in positions]
    essential_gene_abundance[gene] = merged_df[gene_columns].sum(axis=1).sum()


In [None]:
gene_sums = intermediate.sum()
mean = np.mean(gene_sums)
median = np.median(gene_sums)
min_value = np.min(gene_sums)
max_value = np.max(gene_sums)

In [None]:
# Histogram of gene_sums (column totals of `intermediate`) with the mean and median
# marked by dashed lines; the minimum and maximum appear in the legend only.
plt.figure(figsize=(10, 10))
plt.hist(gene_sums, color='violet')
plt.xlabel('Essential gene Abundance')
plt.ylabel('Frequency')  # was misspelled 'Frequence' in the saved figure
plt.axvline(mean, color='r', linestyle='dashed', linewidth=2, label=f'Mean: {mean:.2f}')
plt.axvline(median, color='b', linestyle='dashed', linewidth=2, label=f'Median: {median:.2f}')
dummy_min = plt.Line2D([], [], color='black',  linewidth=2, label=f'Min: {min_value:.2f}')
dummy_max = plt.Line2D([], [], color='black', linewidth=2, label=f'Max: {max_value:.2f}')

# The legend is built from stand-alone proxy lines rather than from the drawn artists.
handles = [plt.Line2D([], [], color='r', linestyle='dashed', linewidth=2, label=f'Mean: {mean:.2f}'),
        plt.Line2D([], [], color='b', linestyle='dashed', linewidth=2, label=f'Median: {median:.2f}'),
        dummy_min, dummy_max]
plt.legend(handles=handles)
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_frequency.pdf", format="pdf", bbox_inches="tight")
# Save first, then show, matching the other histogram cells of this notebook.
plt.show()

In [None]:
print(f"Minimal gene abundance: {gene_sums.min()}")

# 4) Training of full dataset

## 4.1) Full dataset (base model)

In [None]:
# Load the trained base model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 1024
latent_dim = 32
# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = '/Users/anastasiiashcherbakova/Desktop/2_bigdataset/2_bigdataset/8_final_dataset_new_params/saved_KL_annealing_VAE_BD_100.pt'

# load_model hands back the model together with a batch of generated binary samples
# (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model(input_dim, hidden_dim, latent_dim, path_to_model)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_8_final_dataset_new_params.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_8_final_dataset_new_params.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
plot_color = "violet"
figure_name ="/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_8_final_dataset_new_params.pdf"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

----------

## 4.2) Full dataset (enhanced model with dropout layers in decoder)

In [None]:
# Load the trained enhanced model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 1024
latent_dim = 32
# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = "/Users/anastasiiashcherbakova/Desktop/2_bigdataset/2_bigdataset/9_final_dataset_enhanced/saved_KL_annealing_VAE_BD_100.pt"

# load_model_enhanced hands back the model together with a batch of generated binary
# samples (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model_enhanced(input_dim, hidden_dim, latent_dim, path_to_model)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_9_final_dataset_enhanced.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_9_final_dataset_enhanced.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_9_final_dataset_enhanced.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

## 5) Exploring ways to minimise genome size (new loss)

## 5.1) New loss (VAE v1)

Ran a test model with random initial parameters to see how it would perform with a new loss term (gene abundance) included. L1 regularisation was applied to the features in the model. One note: the gamma and beta parameters used for the new loss were:
beta_start = 0.1
beta_end = 1.0
gamma_start = 1.0
gamma_end = 0.1

In [None]:
# Load the trained new-loss model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 512
latent_dim = 32
# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = "/Users/anastasiiashcherbakova/git_projects/masters_project/genomes/models/saved_8_new_loss_model.pt"

# load_model hands back the model together with a batch of generated binary samples
# (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model(input_dim, hidden_dim, latent_dim, path_to_model)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_8_new_loss.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_8_new_loss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_8_new_loss.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

---------

## 5.2) New loss (VAE v1 but with a dropout layer in decoder)

In [None]:
# Load the trained new-loss enhanced model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 512
latent_dim = 32

# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = "/Users/anastasiiashcherbakova/git_projects/masters_project/genomes/models/saved_8_new_loss_enhanced_model.pt"

# load_model_enhanced hands back the model together with a batch of generated binary
# samples (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model_enhanced(input_dim, hidden_dim, latent_dim, path_to_model)


In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_8_new_loss_enhanced_model.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_8_new_loss_enhanced_model.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_8_new_loss_enhanced_model.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

------

## 5.3) New loss with non-linear annealing (VAE v3)

In [None]:
# Load the trained non-linear-annealing model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 512
latent_dim = 32

# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = "/Users/anastasiiashcherbakova/git_projects/masters_project/genomes/models/saved_11_non_linear_annealing_model.pt"

# load_model hands back the model together with a batch of generated binary samples
# (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model(input_dim, hidden_dim, latent_dim, path_to_model)


In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_11_non_linear_annealing.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_11_non_linear_annealing.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_11_non_linear_annealing.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

_______

## 5.4) New loss with genome size (VAE v2)

In [None]:
# Load the trained genome-size model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 512
latent_dim = 32

# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = "/Users/anastasiiashcherbakova/git_projects/masters_project/genomes/models/saved_13_add_genome_size_model.pt"

# load_model hands back the model together with a batch of generated binary samples
# (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model(input_dim, hidden_dim, latent_dim, path_to_model)


In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_13_add_genome_size.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_13_add_genome_size_model.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_13_add_genome_size.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

______

## 5.5) New loss with non-linear annealing and genome size (VAE v4)

In [None]:
# Load the trained genome-size + cyclic-annealing model from its checkpoint.
# NOTE(review): input_dim is hard-coded; presumably it must equal the number of feature
# columns (data_array_t.shape[1]) — confirm.
input_dim = 55039
hidden_dim = 512
latent_dim = 32

# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = "/Users/anastasiiashcherbakova/git_projects/masters_project/genomes/models/saved_15_genome_size_and_cyclic_annealing_SCALED_model.pt"

# load_model hands back the model together with a batch of generated binary samples
# (presumably drawn by the helper itself — confirm in its definition).
model, binary_generated_samples = load_model(input_dim, hidden_dim, latent_dim, path_to_model)


In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/sampling_10000_genome_size_distribution_15_genome_size_and_cyclic_annealing_SCALED.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(binary_generated_samples, figure_name, plot_color)

In [None]:
# Encode the test split, project the latent vectors onto their first two principal
# components and label every point with its phylogroup.
latents = get_latent_variables(model, test_loader, device)
pca = PCA(n_components=2)
data_pca = pca.fit_transform(latents)
df_pca = pd.DataFrame(data_pca, columns=['PC1', 'PC2']).assign(phylogroup=test_phylogroups)

# Scatter plot of the projection, coloured by phylogroup
plt.figure(figsize=(10, 10))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='phylogroup')
plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/pca_latent_space_visualisation_15_genome_size_and_cyclic_annealing_SCALED.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
essential_genes_count_per_sample = count_essential_genes(binary_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/essential_genes_15_genome_size_and_cyclic_annealing_SCALED.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

In [None]:
# Scatter of genome size against essential-gene count for the generated samples, with a
# least-squares straight line through the points.
# Genome size = number of 1-entries per generated sample (row sum); computed once.
genome_sizes = binary_generated_samples.sum(axis=1)

plt.figure(figsize=(10,10))
plt.scatter(genome_sizes, essential_genes_count_per_sample, color='violet')

# Degree-1 polynomial fit: essential-gene count as a linear function of genome size
coefficients = np.polyfit(genome_sizes, essential_genes_count_per_sample, 1)
trendline = np.poly1d(coefficients)

plt.plot(genome_sizes, trendline(genome_sizes), color='black', linewidth=2)

# The x axis carries the row sums (genome size), not an essential-gene count
plt.xlabel('Genome size (number of genes)')
plt.ylabel('Number of essential genes')

plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/GS_EG_15_genome_size_and_cyclic_annealing_SCALED.pdf", format="pdf", bbox_inches="tight")

In [None]:
# Draw latent vectors from a standard normal distribution and decode them with the
# currently loaded model; gradients are not needed for sampling.
# NOTE(review): no random seed is set here, so every run draws different samples.
# NOTE(review): z is created on the default device — confirm the model lives there too.
num_samples = 10000
with torch.no_grad():
    z = torch.randn(num_samples, latent_dim)  # Sample from the standard normal distribution because the latent space follows normal distribution 
    generated_samples = model.decode(z).cpu().numpy() 

# Binarise the decoder output: 1.0 where the value exceeds the threshold, else 0.0
threshold = 0.5
binary_generated_samples = (generated_samples > threshold).astype(float)

print("Generated samples (binary):\n", binary_generated_samples)
print("\n")
print("Generated samples (sigmoid function output):\n", generated_samples)

In [None]:
# Locate the generated sample with the fewest 1-entries and the latent vector behind it.
total_ones = np.sum(binary_generated_samples, axis=1)
min_ones_index = np.argmin(total_ones)

# Row i of generated_samples is the decoding of z[i], so the latent vector of the smallest
# genome is z[min_ones_index]. Searching generated_samples for the row nearest to
# generated_samples[min_ones_index] always lands on that same row (distance 0 to itself,
# and an earlier identical row would already have been picked by argmin above), so the
# full (num_samples x n_genes) distance computation is unnecessary.
closest_latent_index = min_ones_index

print(f"Closest latent vector (z): {z[closest_latent_index]}")
print(f"Generated sample from closest latent vector:\n {generated_samples[closest_latent_index]}")

In [None]:
sum(binary_generated_samples[min_ones_index])

Sampling additional samples from the minimal genomes region

In [None]:
# Sample additional genomes from the neighbourhood of the latent vector that produced the
# smallest generated genome.
z_of_interest = z[closest_latent_index]
# z_of_interest is already a tensor: clone().detach() copies it without the UserWarning
# that torch.tensor(<tensor>) emits. unsqueeze(0) adds a leading batch dimension so the
# vector broadcasts against the (num_additional_samples, latent_dim) noise below.
z_of_interest_tensor = z_of_interest.clone().detach().unsqueeze(0)

# Standard deviation of the Gaussian perturbation added to the latent vector
noise_std = 0.1

num_additional_samples = 10000
# NOTE(review): no random seed is set here, so every run draws different samples.
with torch.no_grad():
    noise = torch.randn(num_additional_samples, latent_dim) * noise_std
    z_samples = z_of_interest_tensor + noise
    additional_generated_samples = model.decode(z_samples).cpu().numpy()


print("Additional generated samples:")
print(additional_generated_samples)

In [None]:
# Binarise the decoder output: 1.0 where the value exceeds the threshold, else 0.0.
# The name additional_generated_samples is rebound here, so from this cell on it holds
# the binary matrix rather than the raw decoder output.
threshold = 0.5
additional_generated_samples = (additional_generated_samples > threshold).astype(float)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/additional_sampling_10000_genome_size_distribution_15_genome_size_and_cyclic_annealing_SCALED.pdf"
plot_color = "dodgerblue"

plot_samples_distribution(additional_generated_samples, figure_name, plot_color)


In [None]:
np.save('/Users/anastasiiashcherbakova/git_projects/masters_project/data/additional_generated_samples.npy', additional_generated_samples)

In [None]:
essential_genes_count_per_sample = count_essential_genes(additional_generated_samples, essential_gene_positions)

In [None]:
figure_name = "/Users/anastasiiashcherbakova/git_projects/masters_project/figures/additioinal_essential_genes_15_genome_size_and_cyclic_annealing_SCALED.pdf"
plot_color = "violet"

plot_essential_genes_distribution(essential_genes_count_per_sample, figure_name, plot_color)

In [None]:
# Scatter of genome size against essential-gene count for the additional samples, with a
# least-squares straight line through the points.
# Genome size = number of 1-entries per generated sample (row sum); computed once.
additional_genome_sizes = additional_generated_samples.sum(axis=1)

plt.figure(figsize=(10,10))
plt.scatter(additional_genome_sizes, essential_genes_count_per_sample, color='violet')

# Degree-1 polynomial fit: essential-gene count as a linear function of genome size
coefficients = np.polyfit(additional_genome_sizes, essential_genes_count_per_sample, 1)
trendline = np.poly1d(coefficients)

plt.plot(additional_genome_sizes, trendline(additional_genome_sizes), color='black', linewidth=2)

# The x axis carries the row sums (genome size), not an essential-gene count
plt.xlabel('Genome size (number of genes)')
plt.ylabel('Number of essential genes')

plt.savefig("/Users/anastasiiashcherbakova/git_projects/masters_project/figures/additional_GS_EG_15_genome_size_and_cyclic_annealing_SCALED.pdf", format="pdf", bbox_inches="tight")

---------

# 6) Creating lists of lists with all genes in the sampled genomes

In [None]:
# extract_prefix is defined once, in the cell that builds groups_of_gene_positions; it is
# reused here instead of being defined a second time.

# Step 1: Rank the samples by essential-gene count and keep the 100 best.
# A stable argsort yields every sample index exactly once, so tied counts can no longer
# put the same sample into the selection several times.
essential_counts = np.asarray(essential_genes_count_per_sample)
sequence_indices = np.argsort(essential_counts, kind="stable")[::-1][:100]

# Step 2: The corresponding counts, highest first (displayed in the next cell)
top_100_values = essential_counts[sequence_indices]

# Step 3: Get the samples from additional_generated_samples
samples = additional_generated_samples[sequence_indices]

# Step 4: Find what genes they have present.
# all_genes[:-1] drops the last column name of merged_df, which is the label column and
# not part of the generated feature vectors.
gene_names = all_genes[:-1]
present_genes_lists = [gene_names[sample == 1] for sample in samples]

# Step 5: Clean up the gene names and append the dataset's essential genes to every list
cleaned_genes_lists = [
    [extract_prefix(name) for name in genes] + datatset_EG
    for genes in present_genes_lists
]

# NOTE(review): absolute, machine-specific output path.
np.save('/Users/anastasiiashcherbakova/git_projects/masters_project/data/cleaned_genes_lists.npy', np.array(cleaned_genes_lists, dtype=object))

In [None]:
top_100_values

---------

# (EXTRA) 7) Comparing the two different essential genes arrays

This step was done early on: we also compared the essential genes listed on this website (https://shigen.nig.ac.jp/ecoli/pec/) with the essential genes in the dataset, and concluded that we should use the Goodall et al. essential genes

In [None]:
# Convert the tab-separated gene list from the website into a CSV file, then load that CSV.
file_path = '/Users/anastasiiashcherbakova/git_projects/masters_project/data/essential_genes_website.txt'
csv_path = '/Users/anastasiiashcherbakova/git_projects/masters_project/data/essential_genes_website.csv'

df = pd.read_csv(file_path, sep='\t')
df.to_csv(csv_path, index=False)

essential_genes_website = pd.read_csv(csv_path)

In [None]:
essential_genes_website_array = np.array(essential_genes_website['Gene Name'])

In [None]:
# Mark the dataset columns whose name appears in the website's gene list and keep only
# those columns. This rebinds essential_genes_mask and essential_genes_df, which earlier
# cells built from the paper's gene list.
essential_genes_mask = np.isin(all_genes, essential_genes_website_array)
essential_genes_df = merged_df.loc[:, essential_genes_mask].copy()

In [None]:
essential_genes_df

In [None]:
essential_genes_present_array = np.array(essential_genes_df.columns)
print(f"Essential genes present in the dataset: {len(essential_genes_present_array)}")

In [None]:
genes_missing = list(set(essential_genes_website_array) - set(essential_genes_present_array))

In [None]:
print(f"Missing genes: {len(genes_missing)}")

In [None]:
# Find dataset columns whose name starts with a website essential gene that has no exact
# column match. Columns already matched exactly (essential_genes_present_array) are
# skipped; `present_genes` must not be used here because the cell that builds
# cleaned_genes_lists rebinds it to the gene list of a single generated sample.
already_matched = set(essential_genes_present_array)
matched_columns = []

for gene in genes_missing:
    # Escape the gene name so regex metacharacters in it are matched literally.
    pattern = re.compile(re.escape(str(gene)))
    # pattern.match anchors at the start of the column name, so this is a prefix test.
    matches = [col for col in merged_df.columns if pattern.match(col) and col not in already_matched]
    matched_columns.extend(matches)

divided_genes = np.array(matched_columns)
print(divided_genes)
print(len(divided_genes))


In [None]:
missing_genes = ['ssb', 'dnaC', 'metG', 'fabG', 'lptB', 'msbA', 'fbaA', 'lolD', 'topA', 'lptG'] 

In [None]:
# Find the values that are only in the essential genes from the website
unique_in_array1 = np.setdiff1d(essential_genes_website_array, essential_genes_array)

# Find the values that are only in the essential genes from the paper
unique_in_array2 = np.setdiff1d(essential_genes_array, essential_genes_website_array)

print("Values only in website array:", unique_in_array1)
print("Values only in paper array:", unique_in_array2)