# STICI Algorithm Understanding Tutorial

## Split-Transformer with Integrated Convolutions for Genotype Imputation

This notebook walks through the STICI algorithm step by step using simulated sample data.

### Overview
STICI is a deep learning model that uses a novel variation of the Transformer architecture for genotype imputation. It combines:
- Split-Transformer architecture
- Integrated convolutions
- Advanced preprocessing techniques

### Key Features:
1. **Split-Transformer**: Novel variation of transformer architecture
2. **Integrated Convolutions**: CNN components for local pattern recognition
3. **Masking Strategy**: Random masking during training for robustness
4. **Haploid/Diploid Support**: Handles both haploid and diploid genotype data


In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots; 'seaborn-v0_8' exists only in Matplotlib >= 3.6,
# so fall back to the older 'seaborn' style name when it is unavailable
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_palette("husl")

print("Libraries imported successfully!")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

## 1. Understanding the STICI Workflow

The STICI algorithm follows this general workflow:

```
PROGRAM STICI:
  Read the data;
  Perform one-hot encoding on the data;
  Partition the data into training, validation, and test sets;
  IF (data is diploid)
      THEN break them into haploids but keep the order intact;
      ELSE do nothing;
  ENDIF;
  FOR each iteration
      Shuffle haploids in training and validation sets separately;
      FOR each training batch
          Randomly select MaskR% of the values and replace with missing values;
          Train the model using the training batch;
      ENDFOR;
      FOR each validation batch
          Randomly select MaskR% of the values and replace with missing values;
          Evaluate the model using the validation batch;
      ENDFOR;
  ENDFOR;
  Perform prediction using the model on the test set;
  IF (data is diploid)
      THEN replace each pair of consecutive test samples with the corresponding diploid;
      ELSE do nothing;
  ENDIF;
  Save the resulting predictions into a file;
END.
```
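
The two inner loops of the pseudocode can be sketched in NumPy as follows. This is a minimal illustration rather than the repository's training code: `iterate_masked_batches` is a hypothetical helper, the batch size and mask rate are arbitrary, and -1 stands in for the missing-value token.

```python
import numpy as np

def iterate_masked_batches(haploids, batch_size, mask_rate, rng):
    """Yield (masked_batch, target_batch, mask) after shuffling the rows once.

    Hypothetical helper for illustration; -1 marks a masked (missing) value.
    """
    order = rng.permutation(haploids.shape[0])  # shuffle haploids for this pass
    for start in range(0, len(order), batch_size):
        target = haploids[order[start:start + batch_size]]
        mask = rng.random(target.shape) < mask_rate  # True where a value is hidden
        masked = np.where(mask, -1, target)
        yield masked, target, mask

rng = np.random.default_rng(0)
toy_haploids = rng.integers(0, 2, size=(10, 8))  # 10 haploids, 8 variants

for epoch in range(2):
    for masked, target, mask in iterate_masked_batches(toy_haploids, 4, 0.5, rng):
        # A real model would be trained here to recover `target` from `masked`.
        assert np.array_equal(masked[~mask], target[~mask])
```

Because the shuffle and the mask are redrawn on every pass, the model never sees the same masked view of a haplotype twice.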

In [None]:
# Create sample genotype data to demonstrate STICI concepts
np.random.seed(42)

# Simulate genotype data
# 0 = homozygous reference, 1 = heterozygous, 2 = homozygous alternate, -1 = missing
n_samples = 100
n_variants = 50

# Generate sample genotype matrix
genotype_data = np.random.choice([0, 1, 2], size=(n_samples, n_variants), p=[0.5, 0.3, 0.2])

# Introduce some missing values (represented as -1)
missing_mask = np.random.random((n_samples, n_variants)) < 0.1  # 10% missing
genotype_data[missing_mask] = -1

print(f"Generated genotype data shape: {genotype_data.shape}")
print(f"Missing values: {np.sum(genotype_data == -1)} ({np.sum(genotype_data == -1)/(n_samples*n_variants)*100:.1f}%)")
print(f"Genotype distribution:")
unique, counts = np.unique(genotype_data, return_counts=True)
for val, count in zip(unique, counts):
    if val == -1:
        print(f"  Missing (-1): {count} ({count/(n_samples*n_variants)*100:.1f}%)")
    else:
        print(f"  Genotype {val}: {count} ({count/(n_samples*n_variants)*100:.1f}%)")

In [None]:
# Visualize the genotype data
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot 1: Heatmap of genotype data (first 20 samples, first 30 variants)
subset_data = genotype_data[:20, :30]
im = axes[0].imshow(subset_data, cmap='viridis', aspect='auto')
axes[0].set_title('Genotype Data Heatmap\n(First 20 samples, 30 variants)')
axes[0].set_xlabel('Variants')
axes[0].set_ylabel('Samples')
plt.colorbar(im, ax=axes[0], label='Genotype Value')

# Plot 2: Distribution of genotype values
genotype_counts = pd.Series(genotype_data.flatten()).value_counts().sort_index()
axes[1].bar(genotype_counts.index, genotype_counts.values, 
           color=['red', 'blue', 'green', 'orange'])
axes[1].set_title('Distribution of Genotype Values')
axes[1].set_xlabel('Genotype Value')
axes[1].set_ylabel('Count')
axes[1].set_xticks([-1, 0, 1, 2])
axes[1].set_xticklabels(['Missing', 'Hom Ref (0/0)', 'Het (0/1)', 'Hom Alt (1/1)'])

plt.tight_layout()
plt.show()

## 2. Data Preprocessing - One-Hot Encoding

STICI uses one-hot encoding to represent genotype data. Each genotype value is converted to a binary vector.
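
For larger matrices the same encoding can be done without Python loops. The sketch below assumes the four-class layout used in this notebook (0, 1, 2, then missing as the last class); `one_hot_vectorized` is a hypothetical name, not a function from the STICI repository.

```python
import numpy as np

def one_hot_vectorized(genotypes, n_classes=4):
    """One-hot encode genotypes, sending the missing code -1 to the last class."""
    class_index = np.where(genotypes == -1, n_classes - 1, genotypes)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(n_classes)[class_index]

demo = np.array([[0, 1, 2, -1]])
encoded_demo = one_hot_vectorized(demo)
print(encoded_demo.shape)  # (1, 4, 4)
print(encoded_demo[0])
```

Indexing the identity matrix with an integer array returns one row per genotype, so the output gains a trailing axis of length `n_classes`.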

In [None]:
def one_hot_encode_genotypes(genotype_matrix):
    """
    Convert genotype matrix to one-hot encoded format
    0 -> [1, 0, 0, 0]  # homozygous reference
    1 -> [0, 1, 0, 0]  # heterozygous
    2 -> [0, 0, 1, 0]  # homozygous alternate
    -1 -> [0, 0, 0, 1] # missing
    """
    n_samples, n_variants = genotype_matrix.shape
    encoded = np.zeros((n_samples, n_variants, 4))
    
    for i in range(n_samples):
        for j in range(n_variants):
            val = genotype_matrix[i, j]
            if val == 0:
                encoded[i, j, 0] = 1  # homozygous reference
            elif val == 1:
                encoded[i, j, 1] = 1  # heterozygous
            elif val == 2:
                encoded[i, j, 2] = 1  # homozygous alternate
            else:  # val == -1 (missing)
                encoded[i, j, 3] = 1  # missing
    
    return encoded

# Apply one-hot encoding
encoded_data = one_hot_encode_genotypes(genotype_data)
print(f"One-hot encoded data shape: {encoded_data.shape}")
print(f"Original shape: {genotype_data.shape}")
print(f"Encoding adds dimension for 4 possible states: [Hom_Ref, Het, Hom_Alt, Missing]")

# Show example of encoding
print("\nExample encoding for first sample, first 5 variants:")
for i in range(5):
    original_val = genotype_data[0, i]
    encoded_val = encoded_data[0, i, :]
    genotype_names = ['Hom_Ref', 'Het', 'Hom_Alt', 'Missing']
    decoded_name = genotype_names[np.argmax(encoded_val)]
    print(f"  Variant {i}: {original_val} -> {encoded_val} ({decoded_name})")

## 3. Diploid to Haploid Conversion

For diploid data, STICI converts each diploid genotype into two haploid sequences while maintaining order.
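
Keeping the order intact matters because the workflow later merges each pair of consecutive haploid predictions back into one diploid genotype. A minimal sketch of that reverse step is shown here; `haploid_to_diploid` is a hypothetical helper, and treating a genotype as missing whenever either allele is missing is an assumption made for this demo.

```python
import numpy as np

def haploid_to_diploid(haploid_matrix):
    """Merge consecutive haploid rows (2i, 2i+1) back into diploid dosages.

    Hypothetical helper: the dosage is the sum of the two alleles, and a
    genotype is reported as missing (-1) if either allele is missing.
    """
    first, second = haploid_matrix[0::2], haploid_matrix[1::2]
    missing = (first == -1) | (second == -1)
    return np.where(missing, -1, first + second)

pairs = np.array([[0, 0, 1, -1],
                  [0, 1, 1, -1]])
diploid = haploid_to_diploid(pairs)
print(diploid)
```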

In [None]:
def diploid_to_haploid(genotype_matrix):
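    """
    Convert diploid genotypes to haploid pairs
    0 (0/0) -> [0, 0]
    1 (0/1) -> [0, 1]
    2 (1/1) -> [1, 1]
    -1 (missing) -> [-1, -1]

    Rows 2*i and 2*i+1 of the result hold the two haplotypes of sample i.
    Note: an unphased dosage of 1 does not say which haplotype carries the
    alternate allele; this demo always assigns it to the second one.
    """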
    n_samples, n_variants = genotype_matrix.shape
    haploid_data = np.zeros((n_samples * 2, n_variants), dtype=int)
    
    for i in range(n_samples):
        for j in range(n_variants):
            val = genotype_matrix[i, j]
            if val == 0:  # 0/0
                haploid_data[2*i, j] = 0
                haploid_data[2*i+1, j] = 0
            elif val == 1:  # 0/1
                haploid_data[2*i, j] = 0
                haploid_data[2*i+1, j] = 1
            elif val == 2:  # 1/1
                haploid_data[2*i, j] = 1
                haploid_data[2*i+1, j] = 1
            else:  # missing
                haploid_data[2*i, j] = -1
                haploid_data[2*i+1, j] = -1
    
    return haploid_data

# Convert to haploid
haploid_data = diploid_to_haploid(genotype_data)
print(f"Diploid data shape: {genotype_data.shape}")
print(f"Haploid data shape: {haploid_data.shape}")
print(f"Each diploid sample becomes 2 haploid samples")

# Show example conversion
print("\nExample diploid to haploid conversion (first sample, first 10 variants):")
print(f"Diploid:   {genotype_data[0, :10]}")
print(f"Haploid 1: {haploid_data[0, :10]}")
print(f"Haploid 2: {haploid_data[1, :10]}")

## 4. Masking Strategy for Training

STICI uses a random masking strategy during training where a percentage of values are randomly masked (set to missing) to teach the model to impute missing values.
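
The parameter table later in this notebook lists a minimum and a maximum masking rate (0.5 and 0.99). One way to use such a range is to draw a fresh rate for every batch and to score the model only where a known value was hidden. The sketch below illustrates that idea; drawing the rate uniformly is an assumption for illustration, not a confirmed STICI detail, and `mask_batch` is a hypothetical helper.

```python
import numpy as np

def mask_batch(batch, min_rate=0.5, max_rate=0.99, rng=None):
    """Hide a random fraction of the values in `batch`.

    Assumption: the rate is drawn uniformly from [min_rate, max_rate] for each
    batch. Returns the masked batch and the positions that can be scored.
    """
    rng = np.random.default_rng() if rng is None else rng
    rate = rng.uniform(min_rate, max_rate)
    hidden = rng.random(batch.shape) < rate
    scorable = hidden & (batch != -1)  # hidden now, but the true value is known
    return np.where(hidden, -1, batch), scorable

rng = np.random.default_rng(1)
batch = rng.integers(0, 2, size=(4, 12))
masked, scorable = mask_batch(batch, rng=rng)
print(f"hidden and scorable: {scorable.sum()} of {batch.size} values")
```

Restricting the score to `scorable` positions avoids rewarding or penalizing the model at sites whose true value was never observed.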

In [None]:
def apply_random_masking(data, mask_rate=0.5):
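    """
    Apply random masking to training data
    mask_rate: fraction of values to mask (0.5 = 50%)
    Returns the masked copy and the boolean mask; the mask may also cover
    values that were already missing.
    """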
    masked_data = data.copy()
    n_samples, n_variants = data.shape
    
    # Create random mask
    mask = np.random.random((n_samples, n_variants)) < mask_rate
    
    # Apply mask (set masked values to -1)
    masked_data[mask] = -1
    
    return masked_data, mask

# Demonstrate masking on haploid data
masked_haploid, mask = apply_random_masking(haploid_data, mask_rate=0.3)

n_newly_masked = np.sum(mask & (haploid_data != -1))  # exclude already-missing values
print(f"Original missing values: {np.sum(haploid_data == -1)}")
print(f"After masking: {np.sum(masked_haploid == -1)}")
print(f"Newly masked values: {n_newly_masked}")
print(f"Masking rate: {np.mean(mask)*100:.1f}%")

# Visualize masking effect
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Original data (subset)
subset_orig = haploid_data[:20, :30]
im1 = axes[0].imshow(subset_orig, cmap='viridis', aspect='auto')
axes[0].set_title('Original Haploid Data')
axes[0].set_xlabel('Variants')
axes[0].set_ylabel('Haploid Samples')
plt.colorbar(im1, ax=axes[0])

# Mask
subset_mask = mask[:20, :30].astype(int)
im2 = axes[1].imshow(subset_mask, cmap='Reds', aspect='auto')
axes[1].set_title('Applied Mask (Red = Masked)')
axes[1].set_xlabel('Variants')
axes[1].set_ylabel('Haploid Samples')
plt.colorbar(im2, ax=axes[1])

# Masked data
subset_masked = masked_haploid[:20, :30]
im3 = axes[2].imshow(subset_masked, cmap='viridis', aspect='auto')
axes[2].set_title('After Masking')
axes[2].set_xlabel('Variants')
axes[2].set_ylabel('Haploid Samples')
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

## 5. Data Splitting and Preparation

STICI splits data into training, validation, and test sets. The training and validation sets are used with masking, while the test set is used for final evaluation.
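
The pseudocode in Section 1 partitions the data before breaking diploids into haploids, which keeps both haplotypes of an individual in the same set. A shuffled split at the individual level can be sketched as follows; `split_individuals` is a hypothetical helper and the ratios are illustrative defaults.

```python
import numpy as np

def split_individuals(diploid_matrix, train_ratio=0.7, val_ratio=0.15, seed=0):
    """Shuffle individuals, then cut them into train/validation/test blocks."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(diploid_matrix.shape[0])
    train_end = int(len(order) * train_ratio)
    val_end = train_end + int(len(order) * val_ratio)
    return (diploid_matrix[order[:train_end]],
            diploid_matrix[order[train_end:val_end]],
            diploid_matrix[order[val_end:]])  # test set takes the remainder

toy = np.arange(40).reshape(20, 2)  # 20 individuals, 2 variants, unique rows
train, val, test = split_individuals(toy)
print(train.shape, val.shape, test.shape)
```

Each of the three blocks can then be converted to haploids independently, so no individual contributes a haplotype to more than one set.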

In [None]:
# Split data into train, validation, and test sets
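def split_data(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """
    Split rows sequentially (no shuffling) into train, validation, and test sets.
    The test set takes whatever remains after the train and validation rows,
    so test_ratio is informational only.
    """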
    n_samples = data.shape[0]
    
    # Calculate split indices
    train_end = int(n_samples * train_ratio)
    val_end = int(n_samples * (train_ratio + val_ratio))
    
    # Split data
    train_data = data[:train_end]
    val_data = data[train_end:val_end]
    test_data = data[val_end:]
    
    return train_data, val_data, test_data

# Split the haploid data
train_data, val_data, test_data = split_data(haploid_data)

print(f"Original haploid data shape: {haploid_data.shape}")
print(f"Training data shape: {train_data.shape}")
print(f"Validation data shape: {val_data.shape}")
print(f"Test data shape: {test_data.shape}")

# Show data distribution across splits
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, (data_split, title) in enumerate([(train_data, 'Training'), 
                                        (val_data, 'Validation'), 
                                        (test_data, 'Test')]):
    counts = pd.Series(data_split.flatten()).value_counts().sort_index()
    axes[i].bar(counts.index, counts.values, color=['red', 'blue', 'green'])
    axes[i].set_title(f'{title} Data Distribution')
    axes[i].set_xlabel('Genotype Value')
    axes[i].set_ylabel('Count')
    axes[i].set_xticks([-1, 0, 1])
    axes[i].set_xticklabels(['Missing', '0', '1'])

plt.tight_layout()
plt.show()

## 6. STICI Architecture Components

STICI combines several key architectural components:

### 6.1 Split-Transformer Architecture
The Split-Transformer processes sequences in chunks, allowing for better handling of long genomic sequences.
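
To make the chunking concrete, the sketch below computes chunk boundaries for the chunk size (2048), overlap (128), and sites per model (6144) listed in this notebook's parameter table. How STICI itself lays out chunks and merges predictions in the overlapping regions is not described here, so `chunk_windows` is a hypothetical helper that only assumes neighbouring chunks share `overlap` sites.

```python
def chunk_windows(n_sites, chunk_size=2048, overlap=128):
    """Return (start, end) index pairs covering n_sites with overlapping chunks."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be non-negative and smaller than chunk_size")
    step = chunk_size - overlap
    windows = []
    start = 0
    while True:
        end = min(start + chunk_size, n_sites)  # the last chunk may be shorter
        windows.append((start, end))
        if end == n_sites:
            break
        start += step
    return windows

windows = chunk_windows(6144)
print(windows)
# [(0, 2048), (1920, 3968), (3840, 5888), (5760, 6144)]
```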

### 6.2 Integrated Convolutions
Convolutional layers capture local patterns in genomic data before feeding into the transformer.
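
As a rough illustration of what "local patterns" means, the sketch below applies a 1D convolution along the variant axis of a one-hot encoded haplotype using plain NumPy. The kernel width, channel counts, and random weights are arbitrary assumptions; the real model learns its kernels, and `conv1d_same` is a hypothetical helper.

```python
import numpy as np

def conv1d_same(x, kernels):
    """Apply a 1D convolution along the variant axis with zero padding.

    x:       (n_variants, in_channels) one-hot encoded haplotype
    kernels: (width, in_channels, out_channels), width must be odd
    returns: (n_variants, out_channels)
    """
    width = kernels.shape[0]
    pad = width // 2
    padded = np.pad(x, ((pad, pad), (0, 0)))
    # Stack the `width` shifted views so each row sees its local neighbourhood.
    windows = np.stack([padded[i:i + x.shape[0]] for i in range(width)], axis=1)
    return np.einsum("nwc,wco->no", windows, kernels)

rng = np.random.default_rng(0)
haplotype = np.eye(3)[rng.integers(0, 3, size=10)]  # 10 variants, 3 classes
features = conv1d_same(haplotype, rng.normal(size=(5, 3, 8)))
print(features.shape)  # (10, 8)
```

Each output row depends only on a variant and its two neighbours on either side, which is the locality that the attention layers then complement.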

### 6.3 Attention Mechanism
Multi-head attention captures long-range dependencies between variants.
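
The following NumPy sketch shows generic multi-head scaled dot-product self-attention, using the embedding size (128) and head count (16) from this notebook's parameter table. It is a textbook formulation with random weights and no output projection, not the Split-Transformer block itself.

```python
import numpy as np

def softmax(scores):
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, n_heads):
    """Scaled dot-product self-attention over a (n_variants, embed_dim) input."""
    n, d = x.shape
    head_dim = d // n_heads
    # Project, then split the embedding into heads: (n_heads, n, head_dim)
    q, k, v = ((x @ w).reshape(n, n_heads, head_dim).transpose(1, 0, 2)
               for w in (w_q, w_k, w_v))
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))  # (heads, n, n)
    merged = (weights @ v).transpose(1, 0, 2).reshape(n, d)
    return merged, weights

rng = np.random.default_rng(0)
embed_dim, n_heads, n_variants = 128, 16, 32
x = rng.normal(size=(n_variants, embed_dim))
w_q, w_k, w_v = (rng.normal(size=(embed_dim, embed_dim)) / np.sqrt(embed_dim)
                 for _ in range(3))
out, weights = multi_head_self_attention(x, w_q, w_k, w_v, n_heads)
print(out.shape, weights.shape)  # (32, 128) (16, 32, 32)
```

Every variant attends to every other variant in the input, so the attention matrix grows quadratically with sequence length; this is one reason to process long sequences in chunks.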

In [None]:
# Draw a simplified block diagram of the STICI architecture
import matplotlib.patches as patches

def visualize_stici_architecture():
    """
    Create a visual representation of STICI architecture
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    
    # Define component positions and sizes
    components = [
        {'name': 'Input\n(One-hot Encoded)', 'pos': (1, 1), 'size': (2, 1), 'color': 'lightblue'},
        {'name': 'Convolutional\nLayers', 'pos': (1, 3), 'size': (2, 1), 'color': 'lightgreen'},
        {'name': 'Split-Transformer\nBlocks', 'pos': (1, 5), 'size': (2, 1.5), 'color': 'lightyellow'},
        {'name': 'Multi-Head\nAttention', 'pos': (4, 5.5), 'size': (1.5, 0.8), 'color': 'lightcoral'},
        {'name': 'Feed Forward\nNetwork', 'pos': (4, 4.5), 'size': (1.5, 0.8), 'color': 'lightcoral'},
        {'name': 'Output Layer\n(Imputed Genotypes)', 'pos': (1, 7.5), 'size': (2, 1), 'color': 'lightpink'}
    ]
    
    # Draw components
    for comp in components:
        rect = patches.Rectangle(comp['pos'], comp['size'][0], comp['size'][1], 
                               linewidth=2, edgecolor='black', facecolor=comp['color'])
        ax.add_patch(rect)
        
        # Add text
        text_x = comp['pos'][0] + comp['size'][0]/2
        text_y = comp['pos'][1] + comp['size'][1]/2
        ax.text(text_x, text_y, comp['name'], ha='center', va='center', 
               fontsize=10, fontweight='bold')
    
    # Draw arrows
    arrows = [
        ((2, 2), (2, 3)),      # Input to Conv
        ((2, 4), (2, 5)),      # Conv to Transformer
        ((3, 5.75), (4, 5.75)), # Transformer to Attention
        ((3, 5.25), (4, 5.25)), # Transformer to FFN
        ((2, 6.5), (2, 7.5))   # Transformer to Output
    ]
    
    for start, end in arrows:
        ax.annotate('', xy=end, xytext=start,
                   arrowprops=dict(arrowstyle='->', lw=2, color='black'))
    
    ax.set_xlim(0, 6)
    ax.set_ylim(0, 9)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('STICI Architecture Overview', fontsize=16, fontweight='bold', pad=20)
    
    plt.tight_layout()
    plt.show()

visualize_stici_architecture()

## 7. Key STICI Parameters

Understanding the important parameters used in STICI:

In [None]:
# STICI Key Parameters (from the GitHub repository)
stici_params = {
    'Architecture': {
        'chunk_size': 2048,           # Chunk size in terms of SNPs/SVs
        'chunk_overlap': 128,         # Chunk overlap in terms of SNPs/SVs
        'sites_per_model': 6144,      # Number of SNPs/SVs used per model
        'embed_dim': 128,             # Embedding dimension size
        'num_attention_heads': 16,    # Number of attention heads
    },
    'Training': {
        'max_masking_rate': 0.99,     # Maximum masking rate
        'min_masking_rate': 0.5,      # Minimum masking rate
        'learning_rate': 0.0005,      # Learning rate
        'batch_size_per_gpu': 4,      # Batch size per GPU
        'max_epochs': 1000,           # Maximum number of epochs
        'val_n_batches': 8,           # Number of batches for validation
    },
    'Data': {
        'use_r2_loss': True,          # Whether to use R^2 loss
        'random_seed': 2022,          # Random seed for reproducibility
    }
}

# Display parameters in a formatted way
print("STICI Key Parameters:")
print("=" * 50)
for category, params in stici_params.items():
    print(f"\n{category}:")
    print("-" * 20)
    for param, value in params.items():
        print(f"  {param:<20}: {value}")

## 8. Evaluation Metrics

STICI's imputation quality can be assessed with several metrics. This section computes accuracy, R², and Pearson correlation on non-missing positions.
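
For reference, the three quantities computed in this section can be written as follows, where $y_i$ is the true value, $\hat{y}_i$ the predicted value, and the sums run over the $n$ positions that are non-missing in both arrays. The notation is introduced here for illustration only.

$$
\mathrm{accuracy} = \frac{1}{n}\sum_{i=1}^{n} \mathbf{1}[y_i = \hat{y}_i],
\qquad
R^2 = 1 - \frac{\sum_{i}(y_i - \hat{y}_i)^2}{\sum_{i}(y_i - \bar{y})^2},
\qquad
r = \frac{\sum_{i}(y_i - \bar{y})(\hat{y}_i - \bar{\hat{y}})}{\sqrt{\sum_{i}(y_i - \bar{y})^2 \, \sum_{i}(\hat{y}_i - \bar{\hat{y}})^2}}
$$

Note that $R^2$ here is the coefficient of determination, which can be negative for poor predictions and is not the same as the squared Pearson correlation $r^2$ often reported in imputation studies.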

In [None]:
def calculate_imputation_metrics(true_values, predicted_values):
    """
    Calculate common imputation evaluation metrics
    """
    # Remove missing values for evaluation
    mask = (true_values != -1) & (predicted_values != -1)
    true_clean = true_values[mask]
    pred_clean = predicted_values[mask]
    
    if len(true_clean) == 0:
        return {'accuracy': 0, 'r2': 0, 'correlation': 0, 'n_compared': 0}
    
    # Accuracy (exact match)
    accuracy = np.mean(true_clean == pred_clean)
    
    # R-squared
    ss_res = np.sum((true_clean - pred_clean) ** 2)
    ss_tot = np.sum((true_clean - np.mean(true_clean)) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0
    
    # Correlation
    correlation = np.corrcoef(true_clean, pred_clean)[0, 1] if len(true_clean) > 1 else 0
    
    return {
        'accuracy': accuracy,
        'r2': r2,
        'correlation': correlation,
        'n_compared': len(true_clean)
    }

# Simulate some predictions for demonstration
np.random.seed(42)
# Create simulated predictions (with some noise)
test_subset = test_data[:10, :20]  # Small subset for demo
simulated_predictions = test_subset.copy()

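# Add some prediction errors (only at non-missing positions, so that -1 is
# never turned into the invalid allele value 2)
error_mask = (np.random.random(test_subset.shape) < 0.1) & (test_subset != -1)  # ~10% error rate
simulated_predictions[error_mask] = 1 - simulated_predictions[error_mask]  # Flip 0 <-> 1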

# Calculate metrics
metrics = calculate_imputation_metrics(test_subset.flatten(), simulated_predictions.flatten())

print("Imputation Evaluation Metrics (Simulated):")
print("=" * 40)
print(f"Accuracy:     {metrics['accuracy']:.3f}")
print(f"R-squared:    {metrics['r2']:.3f}")
print(f"Correlation:  {metrics['correlation']:.3f}")
print(f"Values compared: {metrics['n_compared']}")

# Visualize prediction vs truth
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Heatmap comparison
im1 = axes[0].imshow(test_subset, cmap='viridis', aspect='auto')
axes[0].set_title('True Values')
axes[0].set_xlabel('Variants')
axes[0].set_ylabel('Samples')
plt.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(simulated_predictions, cmap='viridis', aspect='auto')
axes[1].set_title('Predicted Values')
axes[1].set_xlabel('Variants')
axes[1].set_ylabel('Samples')
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

## 9. STICI Usage Examples

Based on the STICI repository, here are example command-line usages:

In [None]:
# STICI Command Line Examples
stici_examples = {
    'Training Mode': {
        'description': 'Train STICI model on reference data',
        'command': '''python3 STICI_V1.1.py \\
    --mode train \\
    --ref reference_data.vcf \\
    --tihp false \\
    --save-dir ./stici_models \\
    --cs 2048 \\
    --co 128 \\
    --max-mr 0.99 \\
    --min-mr 0.5 \\
    --epochs 1000 \\
    --lr 0.0005 \\
    --batch-size-per-gpu 4'''
    },
    'Imputation Mode': {
        'description': 'Use trained model to impute missing genotypes',
        'command': '''python3 STICI_V1.1.py \\
    --mode impute \\
    --ref reference_data.vcf \\
    --target target_data.vcf \\
    --tihp false \\
    --save-dir ./stici_models \\
    --use-trt true \\
    --compress-results true'''
    }
}

print("STICI Usage Examples:")
print("=" * 60)

for mode, info in stici_examples.items():
    print(f"\n{mode}:")
    print(f"Description: {info['description']}")
    print("Command:")
    print(info['command'])
    print("-" * 40)

## 10. Summary and Key Takeaways

### STICI Algorithm Strengths:

1. **Novel Architecture**: Combines Split-Transformer with integrated convolutions
2. **Flexible Input**: Handles both haploid and diploid genotype data
3. **Robust Training**: Uses random masking strategy for better generalization
4. **Scalable**: Designed for large genomic datasets with chunking strategy
5. **High Performance**: Reported by its authors to be more accurate than traditional imputation methods

### Key Components:
- **Data Preprocessing**: One-hot encoding and diploid-to-haploid conversion
- **Masking Strategy**: Random masking during training (50-99% masking rates)
- **Architecture**: Split-Transformer blocks with multi-head attention
- **Evaluation**: Multiple metrics including accuracy, R², and correlation

### Applications:
- Genotype imputation for GWAS studies
- Missing variant imputation
- Population genetics research
- Personalized medicine applications

This tutorial provides a foundation for understanding STICI. For actual implementation, refer to the original repository: https://github.com/shilab/STICI