In [4]:
import pandas as pd
import numpy as np
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

In [5]:
def clean_gene(df, column):
    # Copy the dataframe to avoid modifying the original
    df = df.copy()

    # Replace hyphen '-' and '(' ')' with nothing
    df[column] = df[column].str.replace(r'[-\(\)\'\?\_]', '', regex=True)

    # Convert all upper case to lower case
    df[column] = df[column].str.lower()

    return df


In [6]:
val_data = pd.read_csv("Referece_gene_catalogue_resistance_amr_validation.csv")
val_data.head(10)

Unnamed: 0,Gene family,Class,Subclass
0,aac(2')-Ia,AMINOGLYCOSIDE,GENTAMICIN/TOBRAMCYIN
1,aac(2')-Ib,AMINOGLYCOSIDE,GENTAMICIN/TOBRAMCYIN
2,aac(2')-Ic,AMINOGLYCOSIDE,GENTAMICIN/TOBRAMCYIN
3,aac(2')-Id,AMINOGLYCOSIDE,GENTAMICIN/TOBRAMCYIN
4,aac(2')-Ie,AMINOGLYCOSIDE,GENTAMICIN/TOBRAMCYIN
5,aac(2')-IIa,AMINOGLYCOSIDE,KASUGAMYCIN
6,aac(2')-IIb,AMINOGLYCOSIDE,KASUGAMYCIN
7,aac(3)-I,AMINOGLYCOSIDE,GENTAMICIN
8,aac(3)-I,AMINOGLYCOSIDE,GENTAMICIN
9,aac(3)-I,AMINOGLYCOSIDE,GENTAMICIN


In [7]:
val_data.drop('Class', axis=1, inplace=True)
val_data.drop_duplicates(inplace=True)
val_data['Subclass'] = val_data['Subclass'].str.lower()
val_data['Subclass'] = val_data['Subclass'].str.split('/')
val_data = val_data.explode('Subclass')
val_data['Gene family'] = val_data['Gene family'].str.split('/')
val_data = val_data.explode('Gene family')
val_data.drop_duplicates(inplace=True)
val_data.dropna(inplace=True)
val_data = val_data.reset_index()
val_data = clean_gene(val_data,'Gene family')
val_data

Unnamed: 0,index,Gene family,Subclass
0,0,aac2ia,gentamicin
1,0,aac2ia,tobramcyin
2,1,aac2ib,gentamicin
3,1,aac2ib,tobramcyin
4,2,aac2ic,gentamicin
...,...,...,...
1247,6668,vmlr,lincosamide
1248,6668,vmlr,streptogramin
1249,6668,vmlr,tiamulin
1250,6669,vph,viomycin


In [8]:
val_data.to_csv("val_data.csv", index=False)

In [41]:
data = pd.read_csv(r'isolates.csv')
data.head(10)

Unnamed: 0,#Organism group,Isolate,AMR genotypes,AST phenotypes
0,Listeria monocytogenes,PDT000077416.3,"fosX=COMPLETE,lin=COMPLETE","chloramphenicol=S,clindamycin=R,erythromycin=S..."
1,Listeria monocytogenes,PDT000095192.3,"fosX=COMPLETE,lin=COMPLETE","ampicillin=S,penicillin=S"
2,Salmonella enterica,PDT000003687.3,"mdsA=COMPLETE,mdsB=COMPLETE","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
3,Salmonella enterica,PDT000003688.4,"mdsA=COMPLETE,mdsB=COMPLETE","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
4,Salmonella enterica,PDT000003689.4,"mdsA=COMPLETE,mdsB=COMPLETE","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
5,Salmonella enterica,PDT000003690.3,"aph(3'')-Ib=COMPLETE,aph(6)-Id=COMPLETE,mdsA=C...","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
6,Salmonella enterica,PDT000003691.3,"mdsA=COMPLETE,mdsB=COMPLETE,tet(B)=COMPLETE","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
7,Salmonella enterica,PDT000003692.3,"mdsA=COMPLETE,mdsB=COMPLETE","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
8,Salmonella enterica,PDT000003693.3,"aph(3'')-Ib=COMPLETE,aph(6)-Id=COMPLETE,mdsA=C...","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."
9,Salmonella enterica,PDT000003694.4,"fosA7=COMPLETE,mdsA=COMPLETE,mdsB=COMPLETE","amikacin=S,amoxicillin-clavulanic acid=S,ampic..."


In [42]:
anti_family = pd.read_csv("gene_anti_family.csv")
anti_family.head(10)

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


Unnamed: 0,gene family,gene,antibiotic,drug_class,S,R
0,ANT(2''),ANT(2'')-Ia,spectinomycin,aminoglycoside,,1.0
1,ANT(2''),ANT(2'')-Ia,trimethoprim-sulfamethoxazole,sulfonamide,,
2,ANT(2''),ANT(2'')-Ia,azithromycin,macrolide,,
3,ANT(2''),ANT(2'')-Ia,ceftazidime-avibactam,beta-lactamase,,
4,ANT(2''),ANT(2'')-Ia,piperacillin,Penicillin,,
5,ANT(2''),ANT(2'')-Ia,ticarcillin,Penicillin,,
6,ANT(2''),ANT(2'')-Ia,kanamycin,Aminoglycoside,,1.0
7,ANT(2''),ANT(2'')-Ia,oxacillin,Penicillin,,
8,ANT(2''),ANT(2'')-Ia,tedizolid,Oxazolidinone,,
9,ANT(2''),ANT(2'')-Ia,cefiderocol,beta-lactamase,,


In [43]:
anti_family = clean_gene(anti_family,'gene')
anti_family

Unnamed: 0,gene family,gene,antibiotic,drug_class,S,R
0,ANT(2''),ant2ia,spectinomycin,aminoglycoside,,1.0
1,ANT(2''),ant2ia,trimethoprim-sulfamethoxazole,sulfonamide,,
2,ANT(2''),ant2ia,azithromycin,macrolide,,
3,ANT(2''),ant2ia,ceftazidime-avibactam,beta-lactamase,,
4,ANT(2''),ant2ia,piperacillin,Penicillin,,
...,...,...,...,...,...,...
264379,,mcr2.3,ceftazidime,beta-lactamase,,
264380,,mcr2.3,amoxicillin-clavulanic acid,beta-lactamase,,
264381,,mcr2.3,streptomycin,Aminoglycoside,,
264382,,mcr2.3,moxifloxacin,Quinolone,,


In [44]:
def transform_dataframe(df):
    new_data = []

    for idx, row in df.iterrows():
        # split AMR genotypes and remove "=COMPLETE"
        amr_genotypes = [i.split('=')[0] for i in row['AMR genotypes'].split(',')]

        # split AST phenotypes
        ast_phenotypes = row['AST phenotypes'].split(',')

        # process each phenotype
        for pheno in ast_phenotypes:
            drug, resistance = pheno.split('=')
            if resistance == 'R':
                resistance_score = 1
            elif resistance == 'S':
                resistance_score = 0
            else:  # assuming 'I' as per your description
                resistance_score = 0.5

            # create a new row
            new_row = {'#Organism group': row['#Organism group'],
                       'Isolate': row['Isolate'],
                       'AMR genotypes': ', '.join(amr_genotypes),
                       'drug': drug,
                       'resistance': resistance_score}
            new_data.append(new_row)

    # create a new dataframe
    new_df = pd.DataFrame(new_data)

    return new_df


In [45]:
df = transform_dataframe(data)


In [46]:
df.head(10)

Unnamed: 0,#Organism group,Isolate,AMR genotypes,drug,resistance
0,Listeria monocytogenes,PDT000077416.3,"fosX, lin",chloramphenicol,0.0
1,Listeria monocytogenes,PDT000077416.3,"fosX, lin",clindamycin,1.0
2,Listeria monocytogenes,PDT000077416.3,"fosX, lin",erythromycin,0.0
3,Listeria monocytogenes,PDT000077416.3,"fosX, lin",gentamicin,0.0
4,Listeria monocytogenes,PDT000077416.3,"fosX, lin",levofloxacin,0.0
5,Listeria monocytogenes,PDT000077416.3,"fosX, lin",oxacillin,1.0
6,Listeria monocytogenes,PDT000077416.3,"fosX, lin",penicillin,0.0
7,Listeria monocytogenes,PDT000077416.3,"fosX, lin",rifampin,0.0
8,Listeria monocytogenes,PDT000077416.3,"fosX, lin",tetracycline,0.0
9,Listeria monocytogenes,PDT000077416.3,"fosX, lin",trimethoprim-sulfamethoxazole,0.0


In [47]:
df = clean_gene(df, 'AMR genotypes')
df

Unnamed: 0,#Organism group,Isolate,AMR genotypes,drug,resistance
0,Listeria monocytogenes,PDT000077416.3,"fosx, lin",chloramphenicol,0.0
1,Listeria monocytogenes,PDT000077416.3,"fosx, lin",clindamycin,1.0
2,Listeria monocytogenes,PDT000077416.3,"fosx, lin",erythromycin,0.0
3,Listeria monocytogenes,PDT000077416.3,"fosx, lin",gentamicin,0.0
4,Listeria monocytogenes,PDT000077416.3,"fosx, lin",levofloxacin,0.0
...,...,...,...,...,...
316066,Enterobacter roggenkampii,PDT000898827.2,"blamir, cata, fosa, mcr10.1, oqxa, oqxb",tetracycline,0.0
316067,Enterobacter roggenkampii,PDT000898827.2,"blamir, cata, fosa, mcr10.1, oqxa, oqxb",trimethoprim-sulfamethoxazole,0.0
316068,Enterobacter roggenkampii,PDT001161812.2,"blamir16, cata, fosa, oqxb",ertapenem,0.0
316069,Enterobacter roggenkampii,PDT001161812.2,"blamir16, cata, fosa, oqxb",imipenem,0.5


In [48]:
df.to_csv("BasicData.csv",index=False)

In [514]:
def create_empty_gene_antibiotic_df(df,unique_all_genes,unique_all_antibiotics):
    for idx, row in df.iterrows():
        genotypes = row['AMR genotypes'].split(', ')
        antibiotic = row['drug']

        unique_all_genes.extend(genotypes)
        unique_all_antibiotics.append(antibiotic)

    unique_all_genes = list(set(unique_all_genes))  # Convert to list
    unique_all_antibiotics = list(set(unique_all_antibiotics))  # Convert to list

    gene_antibiotic_df = pd.DataFrame(index=unique_all_genes, columns=unique_all_antibiotics)
    gene_antibiotic_df = gene_antibiotic_df.fillna(-1)

    return gene_antibiotic_df

# Assuming your DataFrame is named 'new_dataframe'
unique_all_genes = []
unique_all_antibiotics = []
gene_antibiotic_df = create_empty_gene_antibiotic_df(df,unique_all_genes,unique_all_antibiotics)
gene_antibiotic_df


Unnamed: 0,chloramphenicol,dicloxacillin,ciprofloxacin,ceftiofur,fosfomycin-glucose-6-phosphate,amoxicillin-clavulanic acid,benzylpenicillin,metronidazole,linezolid,piperacillin,...,trimethoprim-sulfamethoxazole,aztreonam,norfloxacin,neomycin,Imipenem-EDTA-PA,delafloxacin,zoliflodacin,vancomycin,ertapenem,cefiderocol
dfra47,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
glptw355stop,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
aph6,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blapdc,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blaper1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
tet39,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blalap,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blavim2,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blaadc120,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1


In [515]:
gene_antibiotic_df.to_csv("Empty_Gene_Antibiotic.csv", index=True, index_label=False)

***********************************************

### run from here

In [34]:
# def gene_per_drug(drug):
#     drug_df = df[df['drug'] == drug]
#     unique_genes = set()
#     for genes in drug_df['AMR genotypes'].str.split(', '):
#         unique_genes.update(genes)
#     return drug_df, unique_genes

In [49]:
df = pd.read_csv('BasicData.csv')
val_data = pd.read_csv('val_data.csv')
gene_antibiotic_df = pd.read_csv('Empty_Gene_Antibiotic.csv')
anti_family = clean_gene(anti_family,'gene')

In [50]:
gene_antibiotic_df

Unnamed: 0,chloramphenicol,dicloxacillin,ciprofloxacin,ceftiofur,fosfomycin-glucose-6-phosphate,amoxicillin-clavulanic acid,benzylpenicillin,metronidazole,linezolid,piperacillin,...,trimethoprim-sulfamethoxazole,aztreonam,norfloxacin,neomycin,Imipenem-EDTA-PA,delafloxacin,zoliflodacin,vancomycin,ertapenem,cefiderocol
dfra47,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
glptw355stop,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
aph6,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blapdc,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blaper1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
tet39,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blalap,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blavim2,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
blaadc120,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1


In [51]:
def gene_per_drug(drug):
    drug_df = df[df['drug'] == drug]
    unique_genes = set()
    for genes in drug_df['AMR genotypes'].str.split(', '):
        unique_genes.update(genes)

    # Get unique #Organism group values
    organism_groups = drug_df['#Organism group'].unique()

    # Initialize empty lists for train_df and test_df
    train_dfs = []
    test_dfs = []

    # Iterate over #Organism groups
    for group in organism_groups:
        group_df = drug_df[drug_df['#Organism group'] == group]

        if len(group_df) > 1:
            group_train_df, group_test_df = train_test_split(group_df, test_size=0.2, random_state=42)
        else:
            group_train_df = group_df
            group_test_df = group_df

        train_dfs.append(group_train_df)
        test_dfs.append(group_test_df)

    # Concatenate the train and test dataframes for all groups
    train_df = pd.concat(train_dfs)
    test_df = pd.concat(test_dfs)

    return train_df, test_df, unique_genes

# Rest of the code remains the same


In [52]:
def relevant_val_df(drug,unique_genes):
    relevant_df = val_data[(val_data['Subclass'] == drug) & (val_data['Gene family'].isin(unique_genes))]
    if relevant_df.shape[0]>1 :
        val_train, val_test = train_test_split(relevant_df, test_size=0.2, random_state=42)
        return val_train, val_test
    return None, None

In [53]:
def create_gene_df(drug_df,val_df, unique_genes):
    gene_arrays = []
    labels = []

    for idx, row in drug_df.iterrows():
        gene_array = np.zeros(len(unique_genes), dtype=int)
        genes = row['AMR genotypes'].split(', ')
        for gene in genes:
            gene_index = list(unique_genes).index(gene)
            gene_array[gene_index] = 1
            
        gene_arrays.append(gene_array.tolist())
        labels.append(row['resistance'])

    gene_df = pd.DataFrame(gene_arrays, columns=unique_genes)
    gene_df['label'] = labels

    if val_df is not None:
        for idx, row in val_df.iterrows():
            gene_array = np.zeros(len(unique_genes), dtype=int)
            gene = row['Gene family']
            gene_index = list(unique_genes).index(gene)
            gene_array[gene_index] = 1

            gene_arrays.append(gene_array.tolist())
            labels.append(1)

        gene_df = pd.DataFrame(gene_arrays, columns=unique_genes)
        gene_df['label'] = labels

    return gene_df.sample(frac=1).reset_index(drop=True)

In [64]:
def create_update_gene_df(drug_df,val_df, unique_genes, drug, zero_col):
    gene_arrays = []
    labels = []
    drug_r = anti_family.loc[(anti_family['antibiotic'] == drug) & (anti_family['R'] == 1), 'gene'].tolist()
    drug_s = anti_family.loc[(anti_family['antibiotic'] == drug) & (anti_family['R'] != 1) & (anti_family['S'] == 'S'), 'gene'].tolist()

    for idx, row in drug_df.iterrows():
        gene_array = np.zeros(len(unique_genes), dtype=int)
        genes = row['AMR genotypes'].split(', ')
        label = row['resistance']

        for gene in genes:
            gene_index = list(unique_genes).index(gene)
            gene_array[gene_index] = 1
            
#             if (gene not in drug_r) and (label == 1):
#                 # Set the value to 0 if the condition is met
#                 gene_array[gene_index] = 0

#             # Check if the gene is in the list of drug genes and label is 0
#             elif (gene in drug_r) and (label == 0):
#                 gene_array[gene_index] = 0
            
#             if (gene in drug_s) and (label == 0):
#                 gene_array[gene_index] = 1
                
        gene_arrays.append(gene_array.tolist())
        labels.append(row['resistance'])

    gene_df = pd.DataFrame(gene_arrays, columns=unique_genes)
    gene_df['label'] = labels

    if val_df is not None:
        for idx, row in val_df.iterrows():
            gene_array = np.zeros(len(unique_genes), dtype=int)
            gene = row['Gene family']
            gene_index = list(unique_genes).index(gene)
            gene_array[gene_index] = 1

            gene_arrays.append(gene_array.tolist())
            labels.append(1)

        gene_df = pd.DataFrame(gene_arrays, columns=unique_genes)
        gene_df['label'] = labels
        
    if len(zero_col)==0:
        zero_col = gene_df.columns[gene_df.eq(0).all(axis=0)]
    gene_df = gene_df.drop(zero_col, axis=1)

    return gene_df.sample(frac=1).reset_index(drop=True), zero_col

In [55]:
def update_df(df, drug, anti_family, zero_col):
    # Get the list of genes for which the antibiotic is the specified drug
    drug_r = anti_family.loc[(anti_family['antibiotic'] == drug) & (anti_family['R'] == 1), 'gene'].tolist()
    drug_s = anti_family.loc[(anti_family['antibiotic'] == drug) & (anti_family['R'] != 1) & (anti_family['S'] == 's'), 'gene'].tolist()

    # Iterate over each row
    for index, row in df.iterrows():
        # Iterate over each column
        for column in df.columns:
            gene = str(column)
            if gene == "label":
                continue
            # Check if the gene is not in the list of drug genes and label is 1 or 0.5
            if ((gene not in drug_r) or (gene in drug_s)) and (row['label'] == 1):
                # Set the value to 0 if the condition is met
                df.loc[index, column] = 0

            # Check if the gene is in the list of drug genes and label is 0
            if (gene in drug_r) and (row['label'] == 0):
                df.loc[index, column] = 0
            
            if (gene in drug_s) and (row['label'] == 0):
                df.loc[index, column] = 1
           
                
                
    if len(zero_col)==0:
        zero_col = df.columns[df.eq(0).all(axis=0)]
    df = df.drop(zero_col, axis=1)
    return df, zero_col


In [56]:
def Model_deep(drug_train_for_model, drug_test_for_model, epochs, batch_size):
    # Assuming your DataFrame with gene features and labels is named `oxacillin_df_for_model`
    # Split the data into training features (genes) and labels
    X_train = drug_train_for_model.drop('label', axis=1).values
    y_train = drug_train_for_model['label'].values

    # Split the data into testing features (genes) and labels
    X_test = drug_test_for_model.drop('label', axis=1).values
    y_test = drug_test_for_model['label'].values

    # Define the neural network model
    model = keras.Sequential()
    model.add(keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)))
    model.add(keras.layers.Dense(16, activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))  # Sigmoid activation for probability between 0 and 1

    # Compile the model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['mse'])

    # Train the model
    model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size)

    # Evaluate the model
    loss, mse = model.evaluate(X_test, y_test)
    pred = model.predict(X_test)
    print('Loss:', loss)
    print('Mean Squared Error:', mse)
    print('len of pred:', pred.shape, 'y_test:', y_test.shape)

    # Make predictions
    return y_test, pred, model


In [57]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

def Model(drug_train_for_model, drug_test_for_model):
    # Split the data into training features (genes) and labels
    X_train = drug_train_for_model.drop('label', axis=1).values
    y_train = drug_train_for_model['label'].values

    # Split the data into testing features (genes) and labels
    X_test = drug_test_for_model.drop('label', axis=1).values
    y_test = drug_test_for_model['label'].values

    # Initialize the model
    model = RandomForestRegressor(n_estimators=100, random_state=0)

    # Train the model
    model.fit(X_train, y_train)

    # Predict the resistance score
    y_pred = model.predict(X_test)

    # Compute mean squared error
    mse = mean_squared_error(y_test, y_pred)

    print('Mean Squared Error:', mse)
    print('len of pred:', y_pred.shape, 'y_test:', y_test.shape)

    # Make predictions
    return y_test, y_pred, model

In [58]:
def evaluation(y_test, predictions, drug):
    # Calculate mean squared error
    mse = mean_squared_error(y_test, predictions)

    # Calculate mean absolute error
    mae = mean_absolute_error(y_test, predictions)

    # Calculate R^2 score
    r2 = r2_score(y_test, predictions)

    
    print('Mean Squared Error:', mse)
    print('Mean Absolute Error:', mae)
    print('R^2 Score:', r2)
    return mse, mae, r2


In [59]:
def Create_df_one_gene(drug_df_for_model):
    # Extract the column names (excluding the label column)
    column_names = drug_df_for_model.columns[:-1]

    # Create a new DataFrame with zeros
    df_one_gene = pd.DataFrame(0, index=np.arange(len(column_names)), columns=column_names)

    # Set the value at the corresponding index position in each row
    for i in range(len(column_names)):
        df_one_gene.iloc[i, i] = 1

    # Display the new DataFrame
    return df_one_gene

In [60]:
def PredR_Antibiotic(drug):
    print('************************\n', drug)
    zero_lst = []
    train_df, test_df,drug_genes =   gene_per_drug(drug)
    val_train, val_test= relevant_val_df(drug, drug_genes)
    drug_train_for_model, zero_lst = create_update_gene_df(train_df, val_train, list(drug_genes), drug, zero_lst)
    drug_test_for_model, zero_lst = create_update_gene_df(test_df, val_test, list(drug_genes), drug, zero_lst)
    
#     drug_train_for_model,zero_lst = update_df(drug_train_for_model, drug, anti_family,zero_lst)
#     drug_test_for_model,zero_lst = update_df(drug_test_for_model, drug , anti_family,zero_lst)
    # Combine train and test dataframes
    combined_df = pd.concat([drug_train_for_model, drug_test_for_model], ignore_index=True)

    y_test, predictions, model = Model(drug_train_for_model, drug_test_for_model, 20, 10)
    mse, mae, r2 = evaluation(y_test, predictions, drug)
    df_gene = Create_df_one_gene(combined_df)
    pred = model.predict(df_gene)
    for i, col in enumerate(df_gene.columns):
        gene_antibiotic_df.loc[col, drug] = pred[i]
    return r2

In [65]:
print('************************\n', 'ciprofloxacin')
drug = 'ciprofloxacin'
drug_s = anti_family.loc[(anti_family['antibiotic'] == drug) & (anti_family['R'] != 1) & (anti_family['S'] == 'S'), 'gene'].tolist()

zero_lst = []
train_df, test_df,drug_genes =   gene_per_drug('ciprofloxacin')
val_train, val_test= relevant_val_df('ciprofloxacin', drug_genes)
drug_train_for_model, zero_lst = create_update_gene_df(train_df, val_train, list(drug_genes), 'ciprofloxacin', zero_lst)
drug_test_for_model, zero_lst = create_update_gene_df(test_df, val_test, list(drug_genes), 'ciprofloxacin', zero_lst)

#     drug_train_for_model,zero_lst = update_df(drug_train_for_model, drug, anti_family,zero_lst)
#     drug_test_for_model,zero_lst = update_df(drug_test_for_model, drug , anti_family,zero_lst)
# Combine train and test dataframes
combined_df = pd.concat([drug_train_for_model, drug_test_for_model], ignore_index=True)

y_test, predictions, model = Model_deep(drug_train_for_model, drug_test_for_model,20,10)
mse, mae, r2 = evaluation(y_test, predictions, 'ciprofloxacin')
df_gene = Create_df_one_gene(combined_df)
pred = model.predict(df_gene)

for i, col in enumerate(df_gene.columns):
    if col in drug_s:
        gene_antibiotic_df.loc[col, 'ciprofloxacin'] = 0.0
    else:
        gene_antibiotic_df.loc[col, 'ciprofloxacin'] = pred[i]

************************
 ciprofloxacin
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.13263952732086182
Mean Squared Error: 0.017189996317029
len of pred: (4270, 1) y_test: (4270,)
Mean Squared Error: 0.01718999576035255
Mean Absolute Error: 0.029224673994173035
R^2 Score: 0.8970451312099363


In [66]:
ciprofloxacin_pred = pd.DataFrame(data=gene_antibiotic_df['ciprofloxacin'])
ciprofloxacin_pred.to_csv('ciprofloxacin_pred_1.csv', index=True)

In [67]:
drug_train_for_model.to_csv('drug_train_for_model_1.csv', index=True)

In [464]:
drug_train_for_model[drug_train_for_model['blaact9'] == 1]

4.0

In [433]:
antibiotic_few_samples = []
antibiotic_precent = {} 
for antibiotic in unique_all_antibiotics:
    if len(df[df['drug'] == antibiotic]) < 50:
        antibiotic_few_samples.append(antibiotic)
        continue   
    r2 = PredR_Antibiotic(antibiotic)
    antibiotic_precent[antibiotic] = r2

************************
 chloramphenicol
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.04142970219254494
Mean Squared Error: 0.006413301918655634
len of pred: (2467, 1) y_test: (2467,)
Mean Squared Error: 0.006413300873555611
Mean Absolute Error: 0.015188285626479201
R^2 Score: 0.904324149670955
************************
 clindamycin
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.01982947811484337
Mean Squared Error: 0.0037646470591425896
len of pred: (741, 1) y_test: (741,)
Mean Squared Error: 0.003764647136750669
Mean Absolute Error: 0.008050381121189148
R^2 Score: 0.9581800739554267
**********************

Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.0424027256667614
Mean Squared Error: 0.0104377381503582
len of pred: (172, 1) y_test: (172,)
Mean Squared Error: 0.010437738357735543
Mean Absolute Error: 0.02836709374618461
R^2 Score: 0.8441241536722623
************************
 tetracycline
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.03142552077770233
Mean Squared Error: 0.003840138204395771
len of pred: (4091, 1) y_test: (4091,)
Mean Squared Error: 0.0038401391841515913
Mean Absolute Error: 0.013510844061122598
R^2 Score: 0.9842833635632058
************************
 trimethoprim-sulfamethoxazole
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20

Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.09579049795866013
Mean Squared Error: 0.011664372868835926
len of pred: (2942, 1) y_test: (2942,)
Mean Squared Error: 0.011664371701869026
Mean Absolute Error: 0.04589682262709395
R^2 Score: 0.9005782735109045
************************
 ampicillin
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.016375767067074776
Mean Squared Error: 0.0013492284342646599
len of pred: (3290, 1) y_test: (3290,)
Mean Squared Error: 0.0013492283666049815
Mean Absolute Error: 0.0051542585884246815
R^2 Score: 0.9943328215850792
************************
 cefoxitin
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Ep

Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.04052753001451492
Mean Squared Error: 0.004310470074415207
len of pred: (4183, 1) y_test: (4183,)
Mean Squared Error: 0.004310469548825982
Mean Absolute Error: 0.014646712349151754
R^2 Score: 0.9663946914407351
************************
 kanamycin
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.007658082526177168
Mean Squared Error: 0.0006281728274188936
len of pred: (779, 1) y_test: (779,)
Mean Squared Error: 0.0006281728528140305
Mean Absolute Error: 0.003411637983772987
R^2 Score: 0.9945088511808455
************************
 nalidixic acid
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20

Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.04561665654182434
Mean Squared Error: 0.005064842291176319
len of pred: (1674, 1) y_test: (1674,)
Mean Squared Error: 0.005064842568026624
Mean Absolute Error: 0.016664249769442866
R^2 Score: 0.9412051086488841
************************
 amoxicillin-clavulanic acid
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.09569113701581955
Mean Squared Error: 0.01163542177528143
len of pred: (2942, 1) y_test: (2942,)
Mean Squared Error: 0.011635424435692505
Mean Absolute Error: 0.046698662205938414
R^2 Score: 0.9008250066615591
************************
 ampicillin
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20

Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.045172203332185745
Mean Squared Error: 0.005469997879117727
len of pred: (4270, 1) y_test: (4270,)
Mean Squared Error: 0.005469998592272552
Mean Absolute Error: 0.02048474797627612
R^2 Score: 0.9672389106314879
************************
 gentamicin


KeyboardInterrupt: 

In [439]:
train_df, test_df,drug_genes =   gene_per_drug('ciprofloxacin')
val_train, val_test= relevant_val_df('ciprofloxacin', drug_genes)
drug_train_for_model = create_gene_df(train_df, val_train, list(drug_genes))
drug_test_for_model = create_gene_df(test_df, val_test, list(drug_genes))
zero_lst = []
drug_train_for_model,zero_lst = update_df(drug_train_for_model, 'ciprofloxacin', anti_family,zero_lst)
drug_test_for_model,zero_lst = update_df(drug_test_for_model, 'ciprofloxacin' , anti_family,zero_lst)
# Combine train and test dataframes
combined_df = pd.concat([drug_train_for_model, drug_test_for_model], ignore_index=True)

y_test, predictions, model = Model(drug_train_for_model, drug_test_for_model, 20, 10)
mse, mae, r2 = evaluation(y_test, predictions, 'ciprofloxacin')
df_gene = Create_df_one_gene(combined_df)
pred = model.predict(df_gene)
for i, col in enumerate(df_gene.columns):
    gene_antibiotic_df.loc[col, 'ciprofloxacin'] = pred[i]

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Loss: 0.04518590867519379
Mean Squared Error: 0.005301623605191708
len of pred: (4270, 1) y_test: (4270,)
Mean Squared Error: 0.005301622533409055
Mean Absolute Error: 0.018122820352492096
R^2 Score: 0.968247353873089


In [453]:
gene_antibiotic_df_copy

Unnamed: 0,chloramphenicol,dicloxacillin,ciprofloxacin,ceftiofur,fosfomycin-glucose-6-phosphate,amoxicillin-clavulanic acid,benzylpenicillin,metronidazole,linezolid,piperacillin,...,trimethoprim-sulfamethoxazole,aztreonam,norfloxacin,neomycin,Imipenem-EDTA-PA,delafloxacin,zoliflodacin,vancomycin,ertapenem,cefiderocol
dfra47,-1.0,-1,-1.000000,-1.0,-1,-1.0,-1,-1,-1,-1,...,-1.000000,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
glptw355stop,-1.0,-1,-1.000000,-1.0,-1,-1.0,-1,-1,-1,-1,...,0.904007,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
aph6,-1.0,-1,-1.000000,-1.0,-1,-1.0,-1,-1,-1,-1,...,-1.000000,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
blapdc,-1.0,-1,0.963507,-1.0,-1,-1.0,-1,-1,-1,-1,...,-1.000000,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
blaper1,-1.0,-1,-1.000000,-1.0,-1,-1.0,-1,-1,-1,-1,...,0.906040,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
tet39,-1.0,-1,0.960762,-1.0,-1,-1.0,-1,-1,-1,-1,...,0.865247,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
blalap,-1.0,-1,-1.000000,-1.0,-1,-1.0,-1,-1,-1,-1,...,0.899488,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
blavim2,-1.0,-1,0.960044,-1.0,-1,-1.0,-1,-1,-1,-1,...,-1.000000,-1,-1,-1,-1,-1,-1,-1.0,-1,-1
blaadc120,-1.0,-1,0.960444,-1.0,-1,-1.0,-1,-1,-1,-1,...,-1.000000,-1,-1,-1,-1,-1,-1,-1.0,-1,-1


In [454]:

gene_antibiotic_df.to_csv("gene_antibiotic_df_precent_3.csv",index=True)

In [435]:
antibiotic_few_samples

[]

In [436]:
antibiotic_precent

{'chloramphenicol': 0.9042296916217354,
 'clindamycin': 0.9581800739554267,
 'erythromycin': 0.9949654393820588,
 'gentamicin': 0.9663946914407351,
 'levofloxacin': 0.9378462191488209,
 'oxacillin': 0.9999999405766836,
 'penicillin': 0.8718455070977329,
 'rifampin': 0.8441241536722623,
 'tetracycline': 0.9837762134825676,
 'trimethoprim-sulfamethoxazole': 0.9938500269472219,
 'vancomycin': 0.9822547179054787,
 'ampicillin': 0.994315944264865,
 'amikacin': 0.9412051086488841,
 'amoxicillin-clavulanic acid': 0.9008250066615591,
 'cefoxitin': 0.9666059370845,
 'ceftiofur': 0.9695972882381078,
 'ceftriaxone': 0.9844440077197201,
 'ciprofloxacin': 0.9672389106314879,
 'kanamycin': 0.9945088511808455,
 'nalidixic acid': 0.9906525365550776,
 'streptomycin': 0.9963276534219546,
 'sulfamethoxazole': 0.994082975984543}

### model - accuraccy & MSE

In [32]:
predictions

array([[0.00322947],
       [0.9983469 ],
       [0.00322947],
       ...,
       [0.00265715],
       [0.00217324],
       [0.0024083 ]], dtype=float32)

### Evaluation - continues (MSE)

Mean Squared Error: 0.01619927366086918
Mean Absolute Error: 0.03154599644390347
R^2 Score: 0.8993035344650906


### one examle test

In [17]:
y_test[1]

1.0

In [18]:
model.predict(X_test[1].reshape(1, -1))

array([[0.99743843]], dtype=float32)

### Test for each Gene seperate

In [466]:


anti_family['R'].sum()


38985.0

In [34]:
pred_gene = model.predict(df_one_gene)
data = pd.DataFrame(data=pred_gene)

In [50]:
data

Unnamed: 0,0
0,0.062887
1,0.002434
2,0.015267
3,0.001975
4,0.001146
...,...
1099,0.016005
1100,0.001224
1101,0.049610
1102,0.007509


In [69]:
genes_list = list(oxacillin_genes)

# Access the element at index 29
gene_at_index = genes_list[1]

print(gene_at_index)

blaADC-155


In [48]:
data = pd.DataFrame(data=pred_gene)
data[data > 0.5].dropna()

Unnamed: 0,0
8,0.642578
12,0.824058
29,0.997132
42,0.636477
90,0.988201
95,0.888557
146,0.908533
150,0.604366
195,0.978971
227,0.639507


In [36]:
pred_gene.max()

0.9973835

In [37]:
pred_gene.min()

1.4436484e-05

In [None]:
# going over all the bacteria that are resistant to "chloramphenicol":
for gene_lst in df[(df.drug == "ciprofloxacin") & (df.resistance == 0)]['AMR genotypes'].values:
# going over all the genes of each bacteria :
    for g in gene_lst.split():
        print(g)
        v = np.zeros(X.shape[1])
        v[feature_names.index(g)]=1
        print(classifier.predict_proba([v]))
    print("***************************")