In [None]:
import pandas as pd
from combat.pycombat import pycombat

# -----------------------------
# Helper function to load and split samples
# -----------------------------
def load_gse_series(gse_id):
    """
    Read one GSE series from disk and split its samples by TNBC status.

    Three CSV files are read from the current working directory, in this
    order: ``tnbc_samples_<gse_id>.csv`` and ``non_tnbc_samples_<gse_id>.csv``
    (sample IDs are taken from the first column of each), then
    ``<gse_id>.csv`` (the expression matrix, which must contain a
    ``Gene Symbol`` column).

    Returns: TNBC_df, NonTNBC_df -- both indexed by gene symbol, with one
    column per requested sample ID.
    """
    def first_column(csv_path):
        # Sample IDs live in the first column, whatever its header is called.
        return pd.read_csv(csv_path).iloc[:, 0].tolist()

    tnbc_ids = first_column(f'tnbc_samples_{gse_id}.csv')
    non_tnbc_ids = first_column(f'non_tnbc_samples_{gse_id}.csv')

    # Gene symbols become the (unnamed) row index. The first data row is
    # dropped unconditionally; it is treated as an annotation row.
    expression = (
        pd.read_csv(f'{gse_id}.csv')
        .set_index('Gene Symbol')
        .rename_axis(None)
        .iloc[1:]
    )

    # Label-based column selection: a sample ID missing from the matrix
    # raises KeyError.
    return expression.loc[:, tnbc_ids], expression.loc[:, non_tnbc_ids]


# -----------------------------
# Load all GSE datasets
# -----------------------------
# Series that provide both TNBC and non-TNBC samples: each call returns the
# (TNBC, non-TNBC) pair of expression matrices for that series.
GSE76275_TNBC, GSE76275_NonTNBC = load_gse_series(gse_id='GSE76275')
GSE95700_TNBC, GSE95700_NonTNBC = load_gse_series(gse_id='GSE95700')
GSE65194_TNBC, GSE65194_NonTNBC = load_gse_series(gse_id='GSE65194')
GSE18864_TNBC, GSE18864_NonTNBC = load_gse_series(gse_id='GSE18864')

# GSE58812 and GSE83937 do not have non-TNBC samples, so the whole matrix is
# kept: gene symbols become the unnamed row index and the first data row is
# dropped, exactly as load_gse_series does for the paired series.
GSE58812 = (
    pd.read_csv('GSE58812.csv')
    .set_index('Gene Symbol')
    .rename_axis(None)
    .iloc[1:]
)

GSE83937 = (
    pd.read_csv('GSE83937.csv')
    .set_index('Gene Symbol')
    .rename_axis(None)
    .iloc[1:]
)


# -----------------------------
# Combine TNBC datasets
# -----------------------------
# Every TNBC matrix, one entry per series. The position in this list is what
# later determines each series' batch id.
tnbc_dataframes = [
    GSE76275_TNBC,
    GSE95700_TNBC,
    GSE65194_TNBC,
    GSE18864_TNBC,
    GSE58812,
    GSE83937,
]

# Place the matrices side by side (samples as columns), aligned on the
# gene-symbol row index.
# NOTE(review): 'final.csv' is written again after batch correction further
# down, so this uncorrected copy only survives if the script stops before then.
tnbc_combined = pd.concat(tnbc_dataframes, axis='columns')
tnbc_combined.to_csv(path_or_buf='final.csv')


# -----------------------------
# Combine non-TNBC datasets
# -----------------------------
# Every non-TNBC matrix, listed in the same series order as the first four
# TNBC matrices (batch ids are later assigned by position).
non_tnbc_dataframes = [
    GSE76275_NonTNBC,
    GSE95700_NonTNBC,
    GSE65194_NonTNBC,
    GSE18864_NonTNBC,
]

# Place the matrices side by side (samples as columns), aligned on the
# gene-symbol row index.
# NOTE(review): 'final_non.csv' is written again after batch correction
# further down, so this uncorrected copy only survives if the script stops
# before then.
non_tnbc_combined = pd.concat(non_tnbc_dataframes, axis='columns')
non_tnbc_combined.to_csv(path_or_buf='final_non.csv')


# -----------------------------
# Check duplicated sample names
# -----------------------------
# Sample names repeated *within* a single matrix (same check and same output
# line as before).
duplicate_sample_names = []

for df in non_tnbc_dataframes + tnbc_dataframes:
    duplicates = df.columns[df.columns.duplicated()]
    duplicate_sample_names.extend(duplicates.tolist())

print("Duplicate Sample Names:", duplicate_sample_names)

# Sample names repeated *anywhere* across the matrices. The per-matrix check
# above cannot see a name shared by two different matrices, yet once the
# matrices are concatenated such a name makes label-based column selection
# return repeated columns. Each offending name is reported once.
all_sample_names = pd.Index(
    [name for df in non_tnbc_dataframes + tnbc_dataframes for name in df.columns]
)
cross_dataset_duplicates = (
    all_sample_names[all_sample_names.duplicated()].unique().tolist()
)

print("Duplicate Sample Names Across Datasets:", cross_dataset_duplicates)


# -----------------------------
# Batch Correction with ComBat
# -----------------------------
# One wide matrix holding every sample: all TNBC sets first, then all
# non-TNBC sets, in list order.
combined_all = pd.concat([*tnbc_dataframes, *non_tnbc_dataframes], axis='columns')

# One batch id per TNBC set, looked up by the set's position in its list.
batch_ids = [1, 2, 3, 4, 5, 6]

# Build one label per sample column, walking the groups in the same order
# used for combined_all so labels and columns line up. Non-TNBC sets reuse
# the ids by position, which is only correct while non_tnbc_dataframes lists
# its series in the same order as the leading entries of tnbc_dataframes.
batch_labels = []
for group in (tnbc_dataframes, non_tnbc_dataframes):
    for position, frame in enumerate(group):
        batch_labels += [batch_ids[position]] * len(frame.columns)

batch_labels = pd.Series(data=batch_labels, index=combined_all.columns, name="Batch")

# Run ComBat on the combined matrix using the per-sample batch labels.
batch_corrected = pycombat(combined_all, batch_labels)

# -----------------------------
# Extract corrected TNBC / non-TNBC matrices
# -----------------------------
# Collect the sample (column) names of each group, in list order.
tnbc_samples = []
for frame in tnbc_dataframes:
    tnbc_samples.extend(frame.columns)

non_tnbc_samples = []
for frame in non_tnbc_dataframes:
    non_tnbc_samples.extend(frame.columns)

# Split the corrected matrix back into its two groups by column label.
tnbc_corrected = batch_corrected[tnbc_samples]
non_tnbc_corrected = batch_corrected[non_tnbc_samples]

# Write the corrected matrices; these replace the uncorrected files that were
# saved under the same names earlier in the script.
tnbc_corrected.to_csv(path_or_buf='final.csv')
non_tnbc_corrected.to_csv(path_or_buf='final_non.csv')

print("Data preparation and ComBat batch correction completed successfully.")