In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Set, Dict, List, Tuple
import logging
from datetime import datetime

class RadiationAnalysisPipeline:
    """Remove control-sample variants from treatment VCFs and summarise what remains.

    The pipeline reads a metadata CSV, collects every variant seen in the
    control samples, writes a filtered VCF per treatment sample containing only
    variants absent from the controls, stores per-sample statistics in
    ``analysis_results`` and renders a four-panel summary figure.

    Attributes:
        vcf_dir: Directory containing ``<SampleName>.vcf`` input files.
        output_dir: Root output directory (log file and ``analysis_results.csv``).
        plots_dir: ``<output_dir>/plots``.
        filtered_dir: ``<output_dir>/filtered_vcfs``.
        metadata: DataFrame loaded from the metadata CSV.
        analysis_results: One row of statistics per processed treatment sample.
        control_variants: Set of ``CHROM_POS_REF_ALT`` identifiers found in controls.
        logger: Logger writing to both a timestamped file and the console.
    """

    # Value (compared case-insensitively) of 'ExposureRate_mGh' that marks a control.
    CONTROL_LABEL = 'control'
    # Columns joined with '_' to build the identifier used to match variants.
    VARIANT_KEY_COLUMNS = ('CHROM', 'POS', 'REF', 'ALT')

    def __init__(self, vcf_dir: str, output_dir: str, metadata_file: str):
        """Create output folders, start logging and load the sample metadata.

        Args:
            vcf_dir: Directory holding one ``<SampleName>.vcf`` file per sample.
            output_dir: Root directory for the log, ``analysis_results.csv``,
                ``plots/`` and ``filtered_vcfs/``; created if missing.
            metadata_file: CSV read with ``pandas.read_csv``. Later steps use
                its ``SampleName``, ``ExposureRate_mGh`` and
                ``TotalExposure_mG`` columns.

        Raises:
            Exception: Any error from ``pandas.read_csv`` is logged and re-raised.
        """
        self.vcf_dir = vcf_dir
        self.output_dir = output_dir
        self.plots_dir = os.path.join(output_dir, 'plots')
        self.filtered_dir = os.path.join(output_dir, 'filtered_vcfs')

        for directory in (self.output_dir, self.plots_dir, self.filtered_dir):
            os.makedirs(directory, exist_ok=True)

        # Logging needs output_dir to exist, so it is configured after makedirs.
        self.setup_logging()
        self.logger.info("Initializing RadiationAnalysisPipeline...")

        try:
            self.metadata = pd.read_csv(metadata_file)
            self.logger.info(f"Read metadata for {len(self.metadata)} samples")
            self.logger.info(f"Metadata columns: {self.metadata.columns.tolist()}")
        except Exception as e:
            self.logger.error(f"Error reading metadata file: {str(e)}")
            raise

        self.analysis_results = pd.DataFrame()
        self.control_variants = set()

        self.logger.info("Pipeline initialization complete")

    def setup_logging(self):
        """Attach a timestamped file handler and a console handler to the logger.

        Handlers are installed on this module's logger rather than through
        ``logging.basicConfig``: ``basicConfig`` silently does nothing once the
        root logger has handlers, so a second pipeline instance in the same
        kernel would never get its own log file.
        """
        log_file = os.path.join(
            self.output_dir,
            f'radiation_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        # Drop handlers left behind by an earlier instance so messages are not
        # duplicated and stale log files are closed.
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()

        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        # This logger owns its handlers; do not echo again through the root logger.
        logger.propagate = False
        self.logger = logger

    def read_vcf_with_headers(self, vcf_file: str) -> Tuple[List[str], pd.DataFrame]:
        """Read a VCF file into its raw header lines and a DataFrame of records.

        Args:
            vcf_file: Path to a tab-separated VCF text file.

        Returns:
            ``(headers, df)`` where ``headers`` holds every line starting with
            ``#`` (newline included) and ``df`` has one row per record, with
            string values and columns named after the ``#CHROM`` header line.

        Raises:
            ValueError: If the file has no ``#CHROM`` header line.
            Exception: Any other read/parse error is logged and re-raised.
        """
        self.logger.info(f"Reading VCF file: {vcf_file}")
        headers = []
        variant_data = []
        column_names = None

        try:
            with open(vcf_file, 'r') as f:
                for line in f:
                    if line.startswith('#'):
                        headers.append(line)
                        if line.startswith('#CHROM'):
                            column_names = line.lstrip('#').rstrip('\r\n').split('\t')
                    elif line.strip():
                        # Only the line terminator is removed so that trailing
                        # empty tab-separated fields are preserved; blank lines
                        # are skipped instead of becoming one-field rows.
                        variant_data.append(line.rstrip('\r\n').split('\t'))

            if column_names is None:
                raise ValueError(f"No '#CHROM' header line found in {vcf_file}")

            df = pd.DataFrame(variant_data, columns=column_names)
            self.logger.info(f"Successfully read {len(df)} variants")
            return headers, df

        except Exception as e:
            self.logger.error(f"Error reading VCF file {vcf_file}: {str(e)}")
            raise

    def _variant_ids(self, variants_df: pd.DataFrame) -> pd.Series:
        """Build ``CHROM_POS_REF_ALT`` identifiers for every row of ``variants_df``."""
        # Vectorised string concatenation; unlike a row-wise apply it also
        # returns a Series (not a DataFrame) when the frame is empty.
        first, *rest = self.VARIANT_KEY_COLUMNS
        ids = variants_df[first].astype(str)
        for column in rest:
            ids = ids + '_' + variants_df[column].astype(str)
        return ids

    @staticmethod
    def _parse_week(sample: str) -> int:
        """Return the integer that follows the first ``_W`` in a sample name."""
        parts = sample.split('_W')
        digits = ''
        if len(parts) > 1:
            # Collect all leading digits so multi-digit weeks (e.g. "_W12") work.
            for char in parts[1]:
                if char not in '0123456789':
                    break
                digits += char
        if not digits:
            raise ValueError(f"Cannot parse week number from sample name '{sample}'")
        return int(digits)

    @staticmethod
    def _numeric_if_possible(values: pd.Series) -> pd.Series:
        """Return ``values`` converted to numbers when no value is lost doing so."""
        converted = pd.to_numeric(values, errors='coerce')
        if converted.notna().sum() == values.notna().sum():
            return converted
        return values

    def process_treatment_sample(self, sample: str) -> Dict:
        """Filter one treatment sample against the control variants.

        Reads ``<vcf_dir>/<sample>.vcf``, drops every record whose identifier is
        in ``self.control_variants``, writes the remainder (with the original
        header lines) to ``<filtered_dir>/<sample>_filtered.vcf`` and returns
        summary statistics.

        Args:
            sample: Value of the ``SampleName`` metadata column; must contain
                ``_W<week number>``.

        Returns:
            Dict with keys ``Sample``, ``Total_Variants``,
            ``Radiation_Specific_Variants``, ``Filtered_Out``,
            ``Percent_Specific``, ``Week``, ``Exposure_Rate`` and
            ``Total_Exposure``.

        Raises:
            IndexError: If ``sample`` has no row in the metadata.
            ValueError: If no week number can be parsed from ``sample``.
        """
        input_vcf = os.path.join(self.vcf_dir, f"{sample}.vcf")
        output_vcf = os.path.join(self.filtered_dir, f"{sample}_filtered.vcf")

        meta_rows = self.metadata[self.metadata['SampleName'] == sample]
        if meta_rows.empty:
            raise IndexError(f"No metadata row found for sample '{sample}'")
        sample_meta = meta_rows.iloc[0]

        # Parse the week before any file is written so a bad name fails early.
        week = self._parse_week(sample)

        headers, variants_df = self.read_vcf_with_headers(input_vcf)

        is_control_variant = self._variant_ids(variants_df).isin(self.control_variants)
        filtered_df = variants_df[~is_control_variant]

        with open(output_vcf, 'w') as f:
            f.writelines(headers)
            for record in filtered_df.itertuples(index=False, name=None):
                f.write('\t'.join(str(value) for value in record) + '\n')

        total_count = len(variants_df)
        specific_count = len(filtered_df)
        # A VCF without records would otherwise divide by zero.
        percent_specific = (
            round(specific_count / total_count * 100, 2) if total_count else 0.0
        )

        return {
            'Sample': sample,
            'Total_Variants': total_count,
            'Radiation_Specific_Variants': specific_count,
            'Filtered_Out': total_count - specific_count,
            'Percent_Specific': percent_specific,
            'Week': week,
            'Exposure_Rate': sample_meta['ExposureRate_mGh'],
            'Total_Exposure': sample_meta['TotalExposure_mG']
        }

    def process_all_samples(self):
        """Collect control variants, filter every treatment sample, save results.

        Samples whose ``ExposureRate_mGh`` equals ``'control'`` (case-insensitive)
        are controls; all others are treatments. Per-sample failures are logged
        and skipped. Results are stored in ``self.analysis_results`` and written
        to ``<output_dir>/analysis_results.csv``.
        """
        self.logger.info("Starting analysis of treatment samples...")

        is_control = (
            self.metadata['ExposureRate_mGh'].astype(str).str.lower()
            == self.CONTROL_LABEL
        )
        control_samples = self.metadata.loc[is_control, 'SampleName'].tolist()
        treatment_samples = self.metadata.loc[~is_control, 'SampleName'].tolist()

        self.logger.info("Collecting variants from control samples...")
        self.control_variants = set()

        for control_sample in control_samples:
            control_vcf = os.path.join(self.vcf_dir, f"{control_sample}.vcf")
            if not os.path.exists(control_vcf):
                # A missing control weakens the filter, so make it visible.
                self.logger.warning(f"Control VCF not found, skipping: {control_vcf}")
                continue
            _, variants_df = self.read_vcf_with_headers(control_vcf)
            variant_ids = self._variant_ids(variants_df)
            self.control_variants.update(variant_ids)
            self.logger.info(f"Added {len(variant_ids)} variants from {control_sample}")

        self.logger.info(f"Found {len(self.control_variants)} unique variants in control samples")

        results = []
        for sample in treatment_samples:
            self.logger.info(f"Processing treatment sample: {sample}")
            try:
                results.append(self.process_treatment_sample(sample))
                self.logger.info(f"Successfully processed {sample}")
            except Exception as e:
                # Best effort: one broken sample must not abort the whole run.
                self.logger.error(f"Error processing sample {sample}: {str(e)}")

        self.analysis_results = pd.DataFrame(results)

        results_file = os.path.join(self.output_dir, 'analysis_results.csv')
        self.analysis_results.to_csv(results_file, index=False)
        self.logger.info(f"Analysis results saved to: {results_file}")

        self.logger.info("\nAnalysis Summary:")
        self.logger.info(f"Total samples processed: {len(results)}")
        if self.analysis_results.empty:
            # Without rows the statistic columns do not exist.
            self.logger.warning("No treatment samples were processed successfully")
            return
        self.logger.info(f"Average radiation-specific variants: "
                        f"{self.analysis_results['Radiation_Specific_Variants'].mean():.2f}")
        self.logger.info(f"Average percent specific: "
                        f"{self.analysis_results['Percent_Specific'].mean():.2f}%")

    def create_visualizations(self):
        """Render the 2x2 summary figure to ``<plots_dir>/radiation_analysis_summary.png``.

        Panels: variants vs exposure rate, variants over time per rate, percent
        specific vs total exposure (bubble plot) and a rate-by-week heatmap.
        Nothing is drawn when ``analysis_results`` is empty.
        """
        self.logger.info("Creating visualizations...")

        if self.analysis_results.empty:
            self.logger.warning("No analysis results available; skipping visualizations")
            return

        # 'seaborn' was renamed to 'seaborn-v0_8' in newer matplotlib releases;
        # use whichever name this installation provides.
        for style_name in ('seaborn-v0_8', 'seaborn'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break

        # Exposure values can arrive as strings because the metadata column also
        # holds the 'control' label; plot them as numbers when they all convert.
        plot_data = self.analysis_results.copy()
        for column in ('Exposure_Rate', 'Total_Exposure'):
            plot_data[column] = self._numeric_if_possible(plot_data[column])

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 15))

        # Plot 1: Variants by Exposure Rate
        rate_scatter = ax1.scatter(plot_data['Exposure_Rate'],
                                   plot_data['Radiation_Specific_Variants'],
                                   c=plot_data['Week'],
                                   s=150,
                                   cmap='viridis',
                                   alpha=0.7)
        ax1.set_xscale('log')
        ax1.set_title('Radiation-Specific Variants vs Exposure Rate', pad=20)
        ax1.set_xlabel('Exposure Rate (mGh)')
        ax1.set_ylabel('Number of Radiation-Specific Variants')
        fig.colorbar(rate_scatter, ax=ax1, label='Week')
        ax1.grid(True, alpha=0.3)

        # Plot 2: Variants Over Time
        for rate in sorted(plot_data['Exposure_Rate'].unique()):
            # Sort by week so each line runs left to right.
            rate_data = plot_data[plot_data['Exposure_Rate'] == rate].sort_values('Week')
            ax2.plot(rate_data['Week'],
                     rate_data['Radiation_Specific_Variants'],
                     marker='o',
                     label=f'{rate} mGh',
                     linewidth=2,
                     markersize=8)
        ax2.set_title('Accumulation of Radiation-Specific Variants Over Time', pad=20)
        ax2.set_xlabel('Week')
        ax2.set_ylabel('Number of Radiation-Specific Variants')
        ax2.legend(title='Exposure Rate (mGh)', bbox_to_anchor=(1.05, 1))
        ax2.grid(True, alpha=0.3)

        # Plot 3: Bubble Plot (marker area = variant count / 1000)
        bubble_scatter = ax3.scatter(plot_data['Total_Exposure'],
                                     plot_data['Percent_Specific'],
                                     s=plot_data['Radiation_Specific_Variants'] / 1000,
                                     c=plot_data['Week'],
                                     cmap='viridis',
                                     alpha=0.6)
        ax3.set_xscale('log')
        ax3.set_title('Percentage of Radiation-Specific Variants vs Total Exposure', pad=20)
        ax3.set_xlabel('Total Exposure (mG)')
        ax3.set_ylabel('Percent Specific (%)')
        fig.colorbar(bubble_scatter, ax=ax3, label='Week')
        ax3.grid(True, alpha=0.3)

        # Bubble size legend: empty proxy scatters created on ax3 itself, not
        # on whichever axes pyplot currently considers active.
        legend_elements = [ax3.scatter([], [], s=size / 1000, c='gray', alpha=0.6,
                                       label=f'{size:,} variants')
                           for size in [500000, 1000000, 1500000]]
        ax3.legend(handles=legend_elements, title='Number of Variants',
                   bbox_to_anchor=(1.05, 0.5))

        # Plot 4: Heatmap. pivot_table averages replicate samples sharing a
        # rate/week pair (plain pivot raises on duplicates); missing pairs
        # become NaN, hence the float annotation format.
        pivot_data = plot_data.pivot_table(index='Exposure_Rate',
                                           columns='Week',
                                           values='Radiation_Specific_Variants',
                                           aggfunc='mean')
        sns.heatmap(pivot_data, ax=ax4,
                    annot=True,
                    fmt=',.0f',
                    cmap='YlOrRd',
                    cbar_kws={'label': 'Number of Variants'})
        ax4.set_title('Distribution of Radiation-Specific Variants\nAcross Time and Exposure Rates',
                      pad=20)
        ax4.set_xlabel('Week')
        ax4.set_ylabel('Exposure Rate (mGh)')

        # Save the figure and release it explicitly.
        fig.tight_layout()
        plot_path = os.path.join(self.plots_dir, 'radiation_analysis_summary.png')
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        self.logger.info("Visualizations completed and saved")

def run_analysis(vcf_dir: str, output_dir: str, metadata_file: str):
    """Build a pipeline, filter all samples, draw the plots and return the pipeline.

    Args:
        vcf_dir: Directory with the per-sample VCF files.
        output_dir: Directory that receives all pipeline outputs.
        metadata_file: Path to the sample metadata CSV.

    Returns:
        The ``RadiationAnalysisPipeline`` instance after both stages have run.
    """
    analysis = RadiationAnalysisPipeline(
        vcf_dir=vcf_dir,
        output_dir=output_dir,
        metadata_file=metadata_file,
    )
    # Filtering must come first: the plots are drawn from its results.
    analysis.process_all_samples()
    analysis.create_visualizations()
    return analysis

# usage
if __name__ == "__main__":
    # Input locations: per-sample VCFs and the metadata table describing them.
    metadata_file = "./LDR/vcf_files/all_samples/meta_data.csv"
    vcf_directory = "./LDR/vcf_files/all_samples/vcf"
    # Everything the pipeline produces is written below this directory.
    output_directory = "./radiation_analysis_results"

    pipeline = run_analysis(
        vcf_dir=vcf_directory,
        output_dir=output_directory,
        metadata_file=metadata_file,
    )

In [None]:
# Install the GRCh38 reference genome through SigProfilerMatrixGenerator's installer.
# The build named here should agree with the `reference_genome` value passed to
# sigProfilerExtractor further down in this notebook ("GRCh38").
# NOTE(review): presumably a large one-off download; consider skipping this cell
# on routine re-runs once the genome is installed - confirm.
from SigProfilerMatrixGenerator import install as genInstall
genInstall.install('GRCh38')  # Or 'GRCh37' if your VCF uses that build

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from SigProfilerExtractor import sigpro as sig

# Run SigProfilerExtractor and store results
def run_sigprofiler_analysis(input_data, output_dir, project_name,
                             reference_genome="GRCh38",
                             minimum_signatures=1,
                             maximum_signatures=10,
                             nmf_replicates=100,
                             cpu=-1):
    """Run SigProfilerExtractor on a directory of VCF files.

    Args:
        input_data: Path handed to the extractor as ``input_data`` (used with
            ``input_type="vcf"``).
        output_dir: Path handed to the extractor as ``output``.
        project_name: Accepted for interface compatibility; not forwarded to
            the extractor.
        reference_genome: Genome build name passed to the extractor.
        minimum_signatures: Passed through as ``minimum_signatures``.
        maximum_signatures: Passed through as ``maximum_signatures``.
        nmf_replicates: Passed through as ``nmf_replicates``.
        cpu: Passed through as ``cpu``.

    Returns:
        Whatever ``sig.sigProfilerExtractor`` returns, or ``None`` if it raised
        (the error is printed, not re-raised).
    """
    try:
        # Use the caller's argument; the previous code read the notebook-global
        # `input_vcf`, silently ignoring `input_data`.
        return sig.sigProfilerExtractor(
            input_type="vcf",
            output=output_dir,
            input_data=input_data,
            reference_genome=reference_genome,
            minimum_signatures=minimum_signatures,
            maximum_signatures=maximum_signatures,
            nmf_replicates=nmf_replicates,
            cpu=cpu  # Can add GPU option here
        )
    except Exception as e:
        # Deliberate best effort: report the failure and let the notebook continue.
        print(f"Error running SigProfilerExtractor: {str(e)}")
        return None

In [None]:
# Arguments for the run_sigprofiler_analysis call in the next cell.
# input_vcf is the folder the filtering pipeline writes to
# (<output_dir>/filtered_vcfs under "./radiation_analysis_results").
# NOTE(review): output_dir lies inside the input_vcf folder - confirm the
# extractor's input scan is not affected by the nested output directory.
output_dir = "./radiation_analysis_results/filtered_vcfs/output/"
input_vcf = "./radiation_analysis_results/filtered_vcfs/"
project_name ="test"

In [None]:
# Launch the signature extraction on the filtered VCFs configured above.
results = run_sigprofiler_analysis(
    input_data=input_vcf,
    output_dir=output_dir,
    project_name=project_name,
)