In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import umap
from sklearn.decomposition import PCA
import numpy as np
import copy as cp

In [None]:
# Read the merged TCGA clinical table from disk.
clinical = pd.read_csv(filepath_or_buffer='../datasets_TCGA/merged/merged_clinical.csv')

In [None]:
# Dimensions (rows, columns) of the raw clinical table, before any filtering.
clinical.shape

In [None]:
# Per-sample omics summaries; per the file names, one lists samples removed for
# having fewer than 2 omics layers and the other the samples kept (more than 1).
to_delete = pd.read_csv('../datasets_TCGA/summary_removed_<2_omics_TCGA.tsv', delimiter='\t')
to_keep = pd.read_csv('../datasets_TCGA/summary_omics_>1_TCGA.tsv', delimiter='\t')

In [None]:
# Dimensions (rows, columns) of the removed-samples summary.
to_delete.shape

In [None]:
# Kept-sample identifiers and their cancer types, as string arrays.
to_keep_ids = to_keep['sample_id'].astype(str).to_numpy()
to_keep_cancer_types = to_keep['cancertype'].astype(str).to_numpy()

In [None]:
# Keep only clinical rows whose sample_id appears in the kept-samples summary,
# i.e. drop samples with fewer than two omics layers.
clinical_filtered = (
    clinical
    .loc[clinical['sample_id'].astype(str).isin(to_keep_ids)]
    .reset_index(drop=True)
)

In [None]:
# Dimensions (rows, columns) after dropping samples with fewer than two omics.
clinical_filtered.shape

In [None]:
# Attach each sample's cancer type from the kept-samples summary.
# Identical (sample_id, cancertype) pairs are deduplicated first: duplicated keys
# on the right-hand side of a merge silently duplicate the clinical rows.
# validate='many_to_one' then raises if one sample_id maps to several cancer types.
cancer_lookup = to_keep[['sample_id', 'cancertype']].drop_duplicates()
merged = clinical_filtered.merge(
    cancer_lookup,
    on='sample_id',
    how='left',
    validate='many_to_one',
)

# Move 'cancertype' to sit right after 'sample_id' for readability.
cols = list(merged.columns)
if 'sample_id' in cols and 'cancertype' in cols:
    cols.insert(cols.index('sample_id') + 1, cols.pop(cols.index('cancertype')))
    merged = merged[cols]

clinical_filtered = merged

In [None]:
# Preview the first rows after the cancer-type merge.
clinical_filtered.head(n=5)

In [None]:
# Keep only the clinical variables used downstream.
keep_columns = ['sample_id',
                'cancertype',
                'gender.demographic',
                'vital_status.demographic',
                'age_at_index.demographic',
                'ajcc_pathologic_stage.diagnoses',
                'primary_diagnosis.diagnoses',
                'ajcc_pathologic_t.diagnoses',
                'ajcc_pathologic_n.diagnoses',
                'ajcc_pathologic_m.diagnoses',
                'tissue_type.samples',
                ]

# .copy() makes this an independent frame. Later cells assign into it (.loc and
# column assignment); a list-based column selection is flagged as a possible copy,
# which would otherwise trigger SettingWithCopyWarning on those assignments.
clinical_filtered = clinical_filtered[keep_columns].copy()

# Every variable except the identifiers is checked for missing values.
cols_to_check = [col for col in clinical_filtered.columns if col not in ('sample_id', 'cancertype')]

# NaN count per variable within each cancer type (rows: cancer types, columns: variables).
nan_counts = clinical_filtered[cols_to_check].isna().groupby(clinical_filtered['cancertype']).sum()

# Plot
nan_counts.plot(kind='bar', figsize=(14, 6), width=0.85)

plt.title('Number of NaNs per Cancer Type per Column')
plt.ylabel('Number of Missing Values')
plt.xlabel('Cancer Type')
plt.xticks(rotation=45, ha='right')
plt.legend(title="Variable", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(axis='y', linestyle='--', alpha=0.6)
# tight_layout runs after the legend exists so the outside-anchored legend is
# taken into account when the layout is computed.
plt.tight_layout()
plt.show()

In [None]:
# Most missing genders fall in sex-specific tumour types (e.g. OV, PRAD), so they
# can be inferred from the cancer type.
clinical_filtered['gender.demographic'].value_counts(dropna=True)

In [None]:
# Number of samples with missing gender.
clinical_filtered['gender.demographic'].isnull().sum()

In [None]:
# Impute missing gender from cancer type, for sex-specific tumour types only.
sex_map = {
    'OV': 'female',
    'UCEC': 'female',
    'CESC': 'female',
    'PRAD': 'male',
    'TGCT': 'male',
}

# fillna only touches missing entries. Cancer types absent from sex_map map to
# NaN, so those samples stay missing.
clinical_filtered['gender.demographic'] = clinical_filtered['gender.demographic'].fillna(
    clinical_filtered['cancertype'].map(sex_map)
)

# Recompute NaN counts per cancer type after imputation.
cols_to_check = [col for col in keep_columns if col not in ['sample_id', 'cancertype']]
nan_counts = clinical_filtered[cols_to_check].isna().groupby(clinical_filtered['cancertype']).sum()

# Plot
nan_counts.plot(kind='bar', figsize=(14, 6), width=0.85)
plt.title('Number of NaNs per Cancer Type per Column (gender imputed biologically)')
plt.ylabel('Number of Missing Values')
plt.xlabel('Cancer Type')
plt.xticks(rotation=45, ha='right')
plt.legend(title="Variable", bbox_to_anchor=(1.01, 1), loc='upper left')
plt.grid(axis='y', linestyle='--', alpha=0.5)
# tight_layout after the legend so the outside-anchored legend is included.
plt.tight_layout()
plt.show()

In [None]:
# Only 13 samples are "Not Reported"; they are converted to NaN below and handled later.
clinical_filtered['vital_status.demographic'].value_counts(dropna=True)

In [None]:
# Missing vital status before recoding "Not Reported".
clinical_filtered["vital_status.demographic"].isnull().sum()

In [None]:
# Treat a "Not Reported" vital status as missing.
clinical_filtered["vital_status.demographic"] = clinical_filtered["vital_status.demographic"].replace(
    to_replace="Not Reported", value=pd.NA
)

In [None]:
# Missing vital status after recoding "Not Reported".
clinical_filtered["vital_status.demographic"].isnull().sum()

In [None]:
# Age is kept as-is; missing values are retained for now.
clinical_filtered['age_at_index.demographic'].value_counts(dropna=True)

In [None]:
# Number of samples with missing age.
clinical_filtered["age_at_index.demographic"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing ages (should equal the NaN count above).
len(clinical_filtered) - clinical_filtered['age_at_index.demographic'].count()

In [None]:
# Pathologic stage: many NaNs and many fine-grained labels (collapsed into major
# stages further below).
clinical_filtered['ajcc_pathologic_stage.diagnoses'].value_counts(dropna=True)

In [None]:
# Number of samples with missing pathologic stage.
clinical_filtered["ajcc_pathologic_stage.diagnoses"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing stages.
len(clinical_filtered) - clinical_filtered['ajcc_pathologic_stage.diagnoses'].count()

In [None]:
# Distribution of free-text primary diagnoses.
clinical_filtered['primary_diagnosis.diagnoses'].value_counts(dropna=True)

In [None]:
# Number of samples with missing primary diagnosis.
clinical_filtered["primary_diagnosis.diagnoses"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing diagnoses.
len(clinical_filtered) - clinical_filtered['primary_diagnosis.diagnoses'].count()

In [None]:
# Coarse 5-way histology grouping of the free-text primary diagnosis.
diag = clinical_filtered['primary_diagnosis.diagnoses']


def collapse_dx(x):
    """Map a primary-diagnosis label to a coarse histological super-class.

    Returns one of 'Carcinoma', 'Melanoma', 'Sarcoma', 'Germ-cell' or 'Other'.
    Missing values (NaN / pd.NA / None, or the literal string 'nan') return
    np.nan, so they stay distinct from 'Other'. Keywords are tested in that
    order and the first match wins.
    """
    # pd.isna catches every missing-value flavour. Casting to str first (as before)
    # turns pd.NA into '<NA>' and None into 'None', which were then classed as 'Other'.
    if pd.isna(x) or x == 'nan':
        return np.nan
    x_low = str(x).lower()
    if 'carcinoma' in x_low:
        return 'Carcinoma'
    if 'melanoma' in x_low:
        return 'Melanoma'
    if 'sarcoma' in x_low:
        return 'Sarcoma'
    if any(term in x_low for term in ['germ cell', 'teratoma', 'seminoma', 'yolk sac', 'embryonal']):
        return 'Germ-cell'
    return 'Other'


# DataFrame.copy() is already a deep copy of the data; copy.deepcopy is not needed.
clinical_filtered_ = clinical_filtered.copy()
clinical_filtered_['dx_group'] = diag.map(collapse_dx)

# Quick sanity check (missing groups shown too)
clinical_filtered_['dx_group'].value_counts(dropna=False)

In [None]:
# Class sizes for the 5-way diagnostic grouping.
# numpy and matplotlib are already imported in the first cell; no re-import here.

# 1. Total dataset size
total = len(clinical_filtered_)
print(f"Total samples after filtering: {total}")

# 2. Patients per class (samples without a dx_group are excluded)
counts = clinical_filtered_['dx_group'].value_counts(dropna=True)
print("\nPatients per diagnostic super-class:")
print(counts)

# 3. Scatter plot: one x position per class, y = class size, coloured by position.
x = np.arange(len(counts))
y = counts.to_numpy()

plt.figure(figsize=(8, 5))
plt.scatter(x, y, c=x)               # colour encodes the class index
plt.xticks(x, counts.index, rotation=45, ha='right')
plt.ylabel("Number of patients")
plt.title("Patient counts by diagnostic super-class")
plt.tight_layout()
plt.show()


In [None]:
# Reduced RNA-seq matrix; the file-name suffix is presumably the number of
# retained features (TODO confirm against the reduction script).
n_variables = 30
rnaseq = pd.read_csv('../datasets_TCGA/merged/reduced_rnaseq_{}.csv'.format(n_variables))

In [None]:
# Dimensions (rows, columns) of the reduced RNA-seq table as loaded.
rnaseq.shape

In [None]:
# Index the expression matrix by sample_id; the guard keeps the cell re-runnable.
if 'sample_id' in rnaseq.columns:
    rnaseq = rnaseq.set_index("sample_id")

# Diagnostic super-class per sample, restricted to samples that have one.
meta = (
    clinical_filtered_
    .loc[clinical_filtered_['dx_group'].notna()]
    .set_index('sample_id')['dx_group']
)

# Restrict both tables to their shared samples, in the same order.
common_ids = meta.index.intersection(rnaseq.index)
rnaseq_sub, dx_groups = rnaseq.loc[common_ids], meta.loc[common_ids].astype('category')

In [None]:
# 4. Project the aligned expression matrix onto its first two principal components.
pca = PCA(n_components=2)
coords = pca.fit_transform(rnaseq_sub.to_numpy())

In [None]:
# 5. PCA scatter, one colour per diagnostic group.
fig, ax = plt.subplots(figsize=(8, 6))
for cat in dx_groups.cat.categories:
    sel = (dx_groups == cat).to_numpy()
    ax.scatter(coords[sel, 0], coords[sel, 1], s=20, label=cat)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('PCA of RNA-seq by Diagnostic Group')
ax.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()

In [None]:
# 6. 2-D UMAP embedding of the same matrix; random_state is fixed for reproducibility.
umap_model = umap.UMAP(random_state=42, n_components=2)
umap_coords = umap_model.fit_transform(rnaseq_sub.to_numpy())



In [None]:
# 7. UMAP scatter, one colour per diagnostic group (small, semi-transparent points).
fig, ax = plt.subplots(figsize=(8, 6))
for cat in dx_groups.cat.categories:
    sel = (dx_groups == cat).to_numpy()
    ax.scatter(umap_coords[sel, 0], umap_coords[sel, 1], s=2, alpha=.7, label=cat)
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
ax.set_title('UMAP of RNA-seq by Diagnostic Group')
ax.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()

In [None]:
# Binary grouping of the primary diagnosis: carcinoma vs everything else.
diag = clinical_filtered['primary_diagnosis.diagnoses']


def collapse_dx_binary(x):
    """Return 'Carcinoma' if the label mentions carcinoma, else 'non_carcinoma'.

    Missing values (NaN / pd.NA / None, or the literal string 'nan') return np.nan.
    Named distinctly from the 5-class ``collapse_dx`` defined earlier so that
    function is not silently shadowed by a redefinition.
    """
    if pd.isna(x) or x == 'nan':
        return np.nan
    return 'Carcinoma' if 'carcinoma' in str(x).lower() else 'non_carcinoma'


clinical_filtered_ = clinical_filtered.copy()
clinical_filtered_['dx_group'] = diag.map(collapse_dx_binary)

# Quick sanity check (missing groups shown too)
clinical_filtered_['dx_group'].value_counts(dropna=False)

In [None]:
# Class sizes for the binary (carcinoma vs non-carcinoma) grouping.

# 1. Total dataset size
total = len(clinical_filtered_)
print(f"Total samples after filtering: {total}")

# 2. Patients per class (samples without a dx_group are excluded)
counts = clinical_filtered_['dx_group'].value_counts(dropna=True)
print("\nPatients per diagnostic super-class:")
print(counts)

# 3. Scatter plot: one x position per class, y = class size, coloured by position.
x = np.arange(len(counts))
y = counts.to_numpy()

plt.figure(figsize=(8, 5))
plt.scatter(x, y, c=x)               # colour encodes the class index
plt.xticks(x, counts.index, rotation=45, ha='right')
plt.ylabel("Number of patients")
plt.title("Patient counts by diagnostic super-class")
plt.tight_layout()
plt.show()

In [None]:
# Index the expression matrix by sample_id (no-op if already done above).
if 'sample_id' in rnaseq.columns:
    rnaseq = rnaseq.set_index("sample_id")

# Binary diagnostic group per sample, restricted to samples that have one.
meta = (
    clinical_filtered_
    .loc[clinical_filtered_['dx_group'].notna()]
    .set_index('sample_id')['dx_group']
)

# Restrict both tables to their shared samples, in the same order.
common_ids = meta.index.intersection(rnaseq.index)
rnaseq_sub, dx_groups = rnaseq.loc[common_ids], meta.loc[common_ids].astype('category')

In [None]:
# 4. First two principal components of the aligned expression matrix.
pca = PCA(n_components=2)
coords = pca.fit_transform(rnaseq_sub.to_numpy())

In [None]:
# 5. PCA scatter coloured by the binary grouping.
fig, ax = plt.subplots(figsize=(8, 6))
for cat in dx_groups.cat.categories:
    sel = (dx_groups == cat).to_numpy()
    ax.scatter(coords[sel, 0], coords[sel, 1], s=20, label=cat)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('PCA of RNA-seq by 2 Groups')
ax.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()

In [None]:
# 6. 2-D UMAP embedding; random_state is fixed for reproducibility.
umap_model = umap.UMAP(random_state=42, n_components=2)
umap_coords = umap_model.fit_transform(rnaseq_sub.to_numpy())



In [None]:
# 7. UMAP scatter coloured by the binary (carcinoma vs non-carcinoma) grouping.
plt.figure(figsize=(8, 6))
for cat in dx_groups.cat.categories:
    mask = (dx_groups == cat).to_numpy()
    plt.scatter(umap_coords[mask, 0], umap_coords[mask, 1], s=20, label=cat)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
# Title matches the companion PCA plot: this is the 2-group split, not the
# 5-class diagnostic grouping.
plt.title('UMAP of RNA-seq by 2 Groups')
plt.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:

# Alternative grouping: every primary diagnosis with at least DX_MIN_COUNT samples
# becomes its own class and the rest are pooled into 'Other'; the RNA-seq data are
# then projected with PCA and UMAP and coloured by that grouping.
DX_MIN_COUNT = 600

# Index both tables by sample_id (guarded so the cell can be re-run).
# NOTE: this re-binds clinical_filtered with sample_id as its index, so every later
# cell (including the final CSV export) sees the indexed frame.
if 'sample_id' in rnaseq.columns:
    rnaseq = rnaseq.set_index('sample_id')
if 'sample_id' in clinical_filtered.columns:
    clinical_filtered = clinical_filtered.set_index('sample_id')

# Work on a copy so dx_group does not leak into clinical_filtered.
clinical_filtered_ = clinical_filtered.copy()

# 1. Diagnoses frequent enough (>= DX_MIN_COUNT) to form their own class.
dx_counts = clinical_filtered['primary_diagnosis.diagnoses'].value_counts()
major_classes = dx_counts[dx_counts >= DX_MIN_COUNT].index.tolist()


def group_diagnosis(dx):
    """Return dx if it is one of major_classes, 'Other' otherwise; missing stays NaN."""
    if pd.isna(dx):
        return np.nan
    if dx in major_classes:
        return dx
    return 'Other'


clinical_filtered_['dx_group'] = clinical_filtered['primary_diagnosis.diagnoses'].apply(group_diagnosis)

# 2. Class sizes
total = len(clinical_filtered_)
print(f"Total samples: {total}")
dx_group_counts = clinical_filtered_['dx_group'].value_counts(dropna=True)
print("\nPatients per group:\n", dx_group_counts)

# 3. Scatter plot of group sizes
x = np.arange(len(dx_group_counts))
y = dx_group_counts.to_numpy()
plt.figure(figsize=(8, 5))
plt.scatter(x, y, c=x)
plt.xticks(x, dx_group_counts.index, rotation=45, ha='right')
plt.ylabel("Number of patients")
# Title states the same inclusive threshold the code applies (>=, not >).
plt.title(f"Patient counts by diagnostic group (≥{DX_MIN_COUNT} kept, rest 'Other')")
plt.tight_layout()
plt.show()

# 4. Align expression data and metadata on shared samples (order follows rnaseq)
meta = clinical_filtered_['dx_group'].dropna()
common_ids = rnaseq.index.intersection(meta.index)
print(f"\nMatched samples: {len(common_ids)}")
rnaseq_sub = rnaseq.loc[common_ids]
dx_groups = meta.loc[common_ids].astype('category')

# 5. PCA
pca = PCA(n_components=2)
pca_coords = pca.fit_transform(rnaseq_sub.to_numpy())

# 6. PCA plot
plt.figure(figsize=(8, 6))
for cat in dx_groups.cat.categories:
    mask = (dx_groups == cat).to_numpy()
    plt.scatter(pca_coords[mask, 0], pca_coords[mask, 1], s=20, label=cat)
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.title('PCA of RNA-seq by Diagnostic Group')
plt.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# 7. UMAP (plotted in the next cell); random_state fixed for reproducibility
umap_model = umap.UMAP(n_components=2, random_state=42)
umap_coords = umap_model.fit_transform(rnaseq_sub.to_numpy())



In [None]:
# 8. UMAP scatter coloured by the major-diagnosis grouping.
fig, ax = plt.subplots(figsize=(8, 6))
for cat in dx_groups.cat.categories:
    sel = (dx_groups == cat).to_numpy()
    ax.scatter(umap_coords[sel, 0], umap_coords[sel, 1], s=2, alpha=.7, label=cat)
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
ax.set_title('UMAP of RNA-seq by Diagnostic Group')
ax.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()

In [None]:
# Pathologic T: collapsed into major sub-classes further below; NaNs kept.
clinical_filtered['ajcc_pathologic_t.diagnoses'].value_counts(dropna=True)

In [None]:
# Number of samples with missing pathologic T.
clinical_filtered["ajcc_pathologic_t.diagnoses"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing pathologic T.
len(clinical_filtered) - clinical_filtered['ajcc_pathologic_t.diagnoses'].count()

In [None]:
# Pathologic N: collapsed into major classes further below; many NaNs, kept.
clinical_filtered['ajcc_pathologic_n.diagnoses'].value_counts(dropna=True)

In [None]:
# Number of samples with missing pathologic N.
clinical_filtered["ajcc_pathologic_n.diagnoses"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing pathologic N.
len(clinical_filtered) - clinical_filtered['ajcc_pathologic_n.diagnoses'].count()

In [None]:
# Pathologic M: grouped further below and kept despite many NaNs.
clinical_filtered['ajcc_pathologic_m.diagnoses'].value_counts(dropna=True)

In [None]:
# Number of samples with missing pathologic M.
clinical_filtered["ajcc_pathologic_m.diagnoses"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing pathologic M.
len(clinical_filtered) - clinical_filtered['ajcc_pathologic_m.diagnoses'].count()

In [None]:
# Tissue type: NaNs are left as-is.
clinical_filtered['tissue_type.samples'].value_counts(dropna=True)

In [None]:
# Number of samples with missing tissue type.
clinical_filtered["tissue_type.samples"].isnull().sum()

In [None]:
# Cross-check: total rows minus non-missing tissue types.
len(clinical_filtered) - clinical_filtered['tissue_type.samples'].count()

In [None]:
# Show every row of long outputs from here on (global pandas display option).
pd.options.display.max_rows = None

In [None]:
# Short column names: drop the GDC '.demographic' / '.diagnoses' / '.samples' suffixes.
rename_columns = {
    'gender.demographic': 'gender',
    'vital_status.demographic': 'vital_status',
    'age_at_index.demographic': 'age',
    'ajcc_pathologic_stage.diagnoses': 'ajcc_pathologic_stage',
    'primary_diagnosis.diagnoses': 'diagnosis',
    'ajcc_pathologic_t.diagnoses': 'ajcc_pathologic_t',
    'ajcc_pathologic_n.diagnoses': 'ajcc_pathologic_n',
    'ajcc_pathologic_m.diagnoses': 'ajcc_pathologic_m',
    'tissue_type.samples': 'tissue_type',
}

clinical_filtered = clinical_filtered.rename(mapper=rename_columns, axis='columns')

In [None]:
# Missing values per column.
clinical_filtered.isnull().sum()

In [None]:
# Number of samples with at least one missing value.
clinical_filtered.isna().any(axis=1).sum()


In [None]:
# Total number of missing cells in the table.
clinical_filtered.isna().sum().sum()

In [None]:
# Collapse AJCC stage sub-levels (A/B/C/S) into their major stage 0-4;
# "Stage X" and "Not Reported" become missing. Unlisted labels are left unchanged.
stage_groups = {
    "0": ["Stage 0"],
    "1": ["Stage I", "Stage IA", "Stage IB", "Stage IS"],
    "2": ["Stage II", "Stage IIA", "Stage IIB", "Stage IIC"],
    "3": ["Stage III", "Stage IIIA", "Stage IIIB", "Stage IIIC"],
    "4": ["Stage IV", "Stage IVA", "Stage IVB", "Stage IVC"],
}
stage_map = {label: code for code, labels in stage_groups.items() for label in labels}
stage_map.update({"Stage X": pd.NA, "Not Reported": pd.NA})

clinical_filtered["ajcc_pathologic_stage"] = clinical_filtered["ajcc_pathologic_stage"].replace(stage_map)

In [None]:
# Stage distribution after grouping.
print(clinical_filtered['ajcc_pathologic_stage'].value_counts(dropna=True))

In [None]:
# Fold stage 0 into stage 1.
clinical_filtered["ajcc_pathologic_stage"] = clinical_filtered["ajcc_pathologic_stage"].replace({"0": "1"})

In [None]:
# Keep diagnoses with at least 600 samples as their own class; pool every other
# non-missing diagnosis into 'Other'. Missing values stay missing.
dx_counts = clinical_filtered['diagnosis'].value_counts()
major_classes = dx_counts.index[dx_counts >= 600].tolist()

clinical_filtered['diagnosis'] = clinical_filtered['diagnosis'].where(
    clinical_filtered['diagnosis'].isna() | clinical_filtered['diagnosis'].isin(major_classes),
    'Other',
)

In [None]:
# Diagnosis distribution after grouping.
print(clinical_filtered['diagnosis'].value_counts(dropna=True))

In [None]:
# Collapse pathologic T into "0" (Tis/T0), "1" (T1*), "2" (T2*), "3" (T3* and T4*),
# and "X" for TX (unknown). Labels not listed are left unchanged.
t_groups = {
    "0": ["Tis", "T0"],
    "1": ["T1", "T1a", "T1b", "T1c", "T1b1", "T1b2", "T1a1"],
    "2": ["T2", "T2a", "T2b", "T2c", "T2a1", "T2a2"],
    "3": ["T3", "T3a", "T3b", "T3c", "T4", "T4a", "T4b", "T4c", "T4d"],
    "X": ["TX"],
}
clinical_filtered["ajcc_pathologic_t"] = clinical_filtered["ajcc_pathologic_t"].replace(
    {label: code for code, labels in t_groups.items() for label in labels}
)


In [None]:
# Pathologic T distribution after grouping.
print(clinical_filtered['ajcc_pathologic_t'].value_counts(dropna=True))

In [None]:
# Fold "0" (Tis/T0) into "1" and treat "X" (unknown) as missing.
clinical_filtered["ajcc_pathologic_t"] = clinical_filtered["ajcc_pathologic_t"].replace({"0": "1", "X": pd.NA})

In [None]:
# Collapse pathologic N into "0" (N0*), "1" (N1*), "2" (N2* and N3*), and "X"
# for NX (unknown). Labels not listed are left unchanged.
n_groups = {
    "0": ["N0", "N0 (i-)", "N0 (i+)", "N0 (mol+)"],
    "1": ["N1", "N1a", "N1b", "N1c", "N1mi"],
    "2": ["N2", "N2a", "N2b", "N2c", "N3", "N3a", "N3b", "N3c"],
    "X": ["NX"],
}
clinical_filtered["ajcc_pathologic_n"] = clinical_filtered["ajcc_pathologic_n"].replace(
    {label: code for code, labels in n_groups.items() for label in labels}
)

In [None]:
# Pathologic N distribution after grouping.
print(clinical_filtered['ajcc_pathologic_n'].value_counts(dropna=True))

In [None]:
# Treat "X" (NX, unknown) as missing.
clinical_filtered["ajcc_pathologic_n"] = clinical_filtered["ajcc_pathologic_n"].replace({"X": pd.NA})

In [None]:
# Collapse pathologic M into "M0" (incl. cM0 (i+)), "M1" (M1 and sub-levels),
# and "MX" (unknown). Labels not listed are left unchanged.
m_groups = {
    "M0": ["M0", "cM0 (i+)"],
    "M1": ["M1", "M1a", "M1b", "M1c"],
    "MX": ["MX"],
}
clinical_filtered["ajcc_pathologic_m"] = clinical_filtered["ajcc_pathologic_m"].replace(
    {label: code for code, labels in m_groups.items() for label in labels}
)

In [None]:
# Pathologic M distribution after grouping.
print(clinical_filtered['ajcc_pathologic_m'].value_counts(dropna=True))

In [None]:
# Treat "MX" (unknown) as missing.
clinical_filtered["ajcc_pathologic_m"] = clinical_filtered["ajcc_pathologic_m"].replace({"MX": pd.NA})

In [None]:
# Every variable except the identifiers is checked for missing values.
cols_to_check = [col for col in clinical_filtered.columns if col not in ('sample_id', 'cancertype')]

# NaN count per variable within each cancer type (rows: cancer types, columns: variables).
nan_counts = clinical_filtered[cols_to_check].isna().groupby(clinical_filtered['cancertype']).sum()

# Plot
nan_counts.plot(kind='bar', figsize=(14, 6), width=0.85)

plt.title('Number of NaNs per Cancer Type per Column after cleaning')
plt.ylabel('Number of Missing Values')
plt.xlabel('Cancer Type')
plt.xticks(rotation=45, ha='right')
plt.legend(title="Variable", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(axis='y', linestyle='--', alpha=0.6)
# tight_layout after the legend so the outside-anchored legend is included.
plt.tight_layout()
plt.show()

In [None]:
# Value distribution of every column except the identifiers and age.
skip = {'sample_id', 'age', 'cancertype'}
for column in [c for c in clinical_filtered.columns if c not in skip]:
    print(f"Value counts for {column}:")
    print(clinical_filtered[column].value_counts())
    print()

## Dealing with missing values

In [None]:
# Count rows that still contain a literal 'X' in any column. The T/N/M 'X' codes
# were converted to NaN above, so this is presumably 0.
rows_with_x = clinical_filtered[clinical_filtered.eq('X').any(axis=1)]
print(len(rows_with_x))

In [None]:
# Rows containing 'X' are intentionally not dropped. The T/N/M 'X' codes were
# already converted to NaN above, and missing values are kept in the saved table.

In [None]:
# Save the cleaned clinical table with sample_id as the index.
# An earlier exploratory cell re-indexes clinical_filtered by sample_id in place.
# Doing it explicitly here means the file format does not depend on whether that
# cell ran; otherwise index=True would write a meaningless 0..n-1 column.
clinical_to_save = (
    clinical_filtered.set_index('sample_id')
    if 'sample_id' in clinical_filtered.columns
    else clinical_filtered
)
clinical_to_save.to_csv('../datasets_TCGA/merged/reduced_clinical.csv', index=True)