# SciTeX Plotting Module - Publication-Ready Visualizations

The `scitex.plt` module extends matplotlib with convenience helpers and better defaults for scientific plotting. For example, `ax.set_xyt()` sets the x-label, y-label, and title in one call.

In [None]:
import scitex as stx
import numpy as np
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt

# Seed NumPy's global random generator once so the data below is reproducible.
# NOTE: every cell draws from this single global stream, so re-running a cell
# or running cells out of order changes the data that later cells see.
SEED = 42
np.random.seed(SEED)

## 1. Basic Plotting with Enhanced Defaults

In [None]:
# Three curves on a shared grid spanning two full periods of sin/cos.
x = np.linspace(0, 4*np.pi, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.sin(x) * np.exp(-x/10)

fig, ax = stx.plt.subplots(figsize=(10, 6))

# (values, legend label, extra line-style kwargs), drawn in this order.
curves = [
    (y1, 'sin(x)', {}),
    (y2, 'cos(x)', {}),
    (y3, 'sin(x)·exp(-x/10)', {'linestyle': '--'}),
]
for values, label, extra_style in curves:
    ax.plot(x, values, label=label, linewidth=2, **extra_style)

# SciTeX shorthand: x-label, y-label and title in a single call.
ax.set_xyt('x (radians)', 'y', 'Trigonometric Functions')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right')

# Write a 300-dpi PNG, then display the figure.
stx.io.save(fig, './plots/basic_trig.png', dpi=300)
stx.plt.show()

## 2. Statistical Plots

In [None]:
# Four simulated groups (50 samples each) with different means and spreads.
groups = ['Control', 'Treatment A', 'Treatment B', 'Treatment C']
data = [
    np.random.normal(100, 15, 50),
    np.random.normal(110, 12, 50),
    np.random.normal(105, 18, 50),
    np.random.normal(115, 10, 50)
]

# Group k is drawn at x = k (1-based): boxplot's default positions, and passed
# explicitly to violinplot below so both panels line up.
positions = range(1, len(groups) + 1)
colors = ['lightblue', 'lightgreen', 'lightcoral', 'lightyellow']

# Side-by-side box plot and violin plot of the same groups.
fig, (ax1, ax2) = stx.plt.subplots(1, 2, figsize=(12, 6))

# Box plot. Group names are applied with set_xticks/set_xticklabels instead of
# boxplot(labels=...): Matplotlib 3.9 renamed that keyword to `tick_labels`
# and deprecated the old name, so this form works on old and new releases.
bp = ax1.boxplot(data, patch_artist=True)  # patch_artist -> boxes are fillable patches
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
ax1.set_xticks(positions)
ax1.set_xticklabels(groups)
ax1.set_xyt('Group', 'Value', 'Box Plot Comparison')
ax1.grid(True, alpha=0.3, axis='y')

# Violin plot with group means marked.
parts = ax2.violinplot(data, positions=positions, showmeans=True)
for body, color in zip(parts['bodies'], colors):
    body.set_facecolor(color)
    body.set_alpha(0.7)
ax2.set_xticks(positions)
ax2.set_xticklabels(groups)
ax2.set_xyt('Group', 'Value', 'Violin Plot Comparison')
ax2.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
stx.io.save(fig, './plots/statistical_comparison.png')
stx.plt.show()

## 3. Heatmaps and 2D Data

In [None]:
# Synthetic dataset: 100 observations of n_vars standard-normal variables.
n_vars = 10
data = np.random.randn(100, n_vars)

# Inject known linear dependencies so the heatmap shows visible structure.
data[:, 1] = data[:, 0] * 0.7 + np.random.randn(100) * 0.3
data[:, 2] = data[:, 0] * -0.5 + np.random.randn(100) * 0.4
data[:, 5] = data[:, 4] * 0.8 + np.random.randn(100) * 0.2

# np.corrcoef treats each ROW as a variable, hence the transpose; the result
# is an (n_vars, n_vars) matrix.
corr_matrix = np.corrcoef(data.T)

fig, ax = stx.plt.subplots(figsize=(10, 8))

# Diverging colormap pinned to [-1, 1] so zero correlation sits at the midpoint.
im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')

cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Correlation Coefficient', rotation=270, labelpad=20)

# One tick per variable on both axes.
var_names = [f'Var{i+1}' for i in range(n_vars)]
ax.set_xticks(range(n_vars))
ax.set_yticks(range(n_vars))
ax.set_xticklabels(var_names, rotation=45, ha='right')
ax.set_yticklabels(var_names)

# Print each coefficient in its cell. imshow draws element [i, j] at (x=j, y=i).
# White text on strongly coloured cells (|r| > 0.5) keeps the value legible.
for i in range(n_vars):
    for j in range(n_vars):
        r = corr_matrix[i, j]
        ax.text(j, i, f'{r:.2f}', ha='center', va='center',
                color='white' if abs(r) > 0.5 else 'black', fontsize=8)

ax.set_title('Correlation Matrix Heatmap', fontsize=14, pad=20)

fig.tight_layout()
stx.io.save(fig, './plots/correlation_heatmap.png')
stx.plt.show()

## 4. Multi-panel Figures

In [None]:
# Multi-panel figure on a 3x3 GridSpec:
#   row 0 -> time series (full width)
#   row 1 -> histogram | scatter + fit | bar chart
#   row 2 -> filled contour (full width)
# NOTE: plt.figure() and fig.add_subplot() return plain Matplotlib Axes, not
# SciTeX-wrapped axes. Matplotlib's Axes class has no `set_xyt` method, so the
# labels here go through the standard Axes API via _label_axes().
fig = plt.figure(figsize=(15, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)


def _label_axes(ax, xlabel, ylabel, title):
    """Set the x-label, y-label and title of a plain Matplotlib Axes."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


# Panel A (top, full width): 1 Hz and 5 Hz sinusoids plus Gaussian noise.
ax1 = fig.add_subplot(gs[0, :])
t = np.linspace(0, 10, 1000)
signal = np.sin(2*np.pi*t) + 0.5*np.sin(10*np.pi*t) + np.random.randn(1000)*0.1
ax1.plot(t, signal, 'b-', linewidth=0.5)
_label_axes(ax1, 'Time (s)', 'Amplitude', 'Time Series Signal')
ax1.grid(True, alpha=0.3)

# Panel B (middle left): amplitude distribution of the signal above.
ax2 = fig.add_subplot(gs[1, 0])
ax2.hist(signal, bins=50, alpha=0.7, color='green', edgecolor='black')
_label_axes(ax2, 'Amplitude', 'Count', 'Distribution')
ax2.grid(True, alpha=0.3, axis='y')

# Panel C (middle centre): noisy linear relationship.
ax3 = fig.add_subplot(gs[1, 1])
x_scatter = np.random.randn(200)
y_scatter = 2*x_scatter + np.random.randn(200)*0.5
ax3.scatter(x_scatter, y_scatter, alpha=0.6, s=30)
_label_axes(ax3, 'X', 'Y', 'Scatter Plot')
ax3.grid(True, alpha=0.3)

# Least-squares line, evaluated on sorted x so it draws as one clean segment.
z = np.polyfit(x_scatter, y_scatter, 1)
p = np.poly1d(z)
x_sorted = np.sort(x_scatter)
ax3.plot(x_sorted, p(x_sorted), "r--", alpha=0.8)

# Panel D (middle right): bar chart of random integer values in [20, 100).
ax4 = fig.add_subplot(gs[1, 2])
categories = ['A', 'B', 'C', 'D', 'E']
values = np.random.randint(20, 100, 5)
bars = ax4.bar(categories, values, color='orange', alpha=0.7)
_label_axes(ax4, 'Category', 'Value', 'Bar Chart')
ax4.grid(True, alpha=0.3, axis='y')

# Value label just above each bar.
for bar, val in zip(bars, values):
    ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
             str(val), ha='center', va='bottom')

# Panel E (bottom, full width): filled contour of a 2D test function.
ax5 = fig.add_subplot(gs[2, :])
x_contour = np.linspace(-3, 3, 100)
y_contour = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x_contour, y_contour)
Z = np.exp(-(X**2 + Y**2)/2) * np.cos(2*X) * np.sin(2*Y)

contour = ax5.contourf(X, Y, Z, levels=20, cmap='viridis')
plt.colorbar(contour, ax=ax5, label='Value')
_label_axes(ax5, 'X', 'Y', '2D Contour Plot')

# Bold panel letters in each panel's top-left corner (axes coordinates).
for ax, label in zip([ax1, ax2, ax3, ax4, ax5], ['A', 'B', 'C', 'D', 'E']):
    ax.text(0.02, 0.98, label, transform=ax.transAxes,
            fontsize=14, fontweight='bold', va='top')

fig.suptitle('Multi-Panel Figure Example', fontsize=16)
stx.io.save(fig, './plots/multi_panel_figure.png', dpi=300)
stx.plt.show()

## 5. Error Bars and Confidence Intervals

In [None]:
# Noisy measurements of the line y = 2x + 1, each with its own error bar.
x = np.linspace(0, 10, 20)
y_true = 2 * x + 1
y_measured = y_true + np.random.normal(0, 2, len(x))
y_error = np.random.uniform(1, 3, len(x))

# Ordinary least-squares straight-line fit.
coeffs = np.polyfit(x, y_measured, 1)
y_fit = np.polyval(coeffs, x)

# Two-sided confidence band for the fitted (mean) regression line:
#   y_hat(x0) +/- t_{1-alpha/2, n-2} * s * sqrt(1/n + (x0 - x_bar)^2 / Sxx)
# where s is the residual standard error and Sxx = sum((x - x_bar)^2).
# The two-sided quantile is 0.5 + confidence/2 (0.975 for 95%); using
# t.ppf(confidence) would give only a 90% two-sided band. The band is
# narrowest at x_bar and widens toward the ends of the data.
from scipy import stats as sp_stats
confidence = 0.95
n_obs = len(x)
dof = n_obs - 2  # two fitted parameters (slope, intercept)
resid_se = np.sqrt(np.sum((y_measured - y_fit)**2) / dof)
t_crit = sp_stats.t.ppf(0.5 + confidence / 2, dof)
x_mean = x.mean()
sxx = np.sum((x - x_mean)**2)

x_smooth = np.linspace(0, 10, 100)
y_smooth = np.polyval(coeffs, x_smooth)
margin = t_crit * resid_se * np.sqrt(1 / n_obs + (x_smooth - x_mean)**2 / sxx)

fig, (ax1, ax2) = stx.plt.subplots(1, 2, figsize=(12, 5))

# Left: raw measurements with error bars, the true line and the fitted line.
ax1.errorbar(x, y_measured, yerr=y_error, fmt='o', capsize=5,
             label='Measurements', markersize=6, alpha=0.7)
ax1.plot(x, y_true, 'g--', label='True relationship', linewidth=2)
ax1.plot(x, y_fit, 'r-', label=f'Fit: y = {coeffs[0]:.2f}x + {coeffs[1]:.2f}', linewidth=2)
ax1.set_xyt('X', 'Y', 'Measurements with Error Bars')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: regression line with its confidence band.
ax2.scatter(x, y_measured, alpha=0.6, label='Data points')
ax2.plot(x_smooth, y_smooth, 'r-', label='Regression line', linewidth=2)
ax2.fill_between(x_smooth, y_smooth - margin, y_smooth + margin,
                 alpha=0.3, color='red', label=f'{confidence:.0%} CI')
ax2.set_xyt('X', 'Y', 'Regression with Confidence Interval')
ax2.legend()
ax2.grid(True, alpha=0.3)

fig.tight_layout()
stx.io.save(fig, './plots/error_bars_confidence.png')
stx.plt.show()

## 6. 3D Plotting

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # importing registers the '3d' projection on older Matplotlib

fig = plt.figure(figsize=(15, 5))

# Shared 50x50 grid for the surface and wireframe panels: a damped radial
# ripple (the +0.1 keeps the denominator non-zero at the origin).
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2)) / np.sqrt(X**2 + Y**2 + 0.1)

# Panel 1: colour-mapped surface.
ax1 = fig.add_subplot(131, projection='3d')
surf = ax1.plot_surface(X, Y, Z, cmap='coolwarm', alpha=0.8)

# Panel 2: noisy paraboloid point cloud, coloured by height.
ax2 = fig.add_subplot(132, projection='3d')
n_points = 500
xs = np.random.normal(0, 1, n_points)
ys = np.random.normal(0, 1, n_points)
zs = xs**2 + ys**2 + np.random.normal(0, 0.1, n_points)
scatter = ax2.scatter(xs, ys, zs, c=zs, cmap='viridis', alpha=0.6)

# Panel 3: translucent wireframe with contours projected onto the z = -2 plane.
ax3 = fig.add_subplot(133, projection='3d')
wire = ax3.plot_wireframe(X, Y, Z, color='blue', alpha=0.3)
ax3.contour(X, Y, Z, zdir='z', offset=-2, cmap='coolwarm')
ax3.set_zlim(-2, 1)

# Same axis labels on every panel; only the titles differ.
panel_titles = (
    (ax1, '3D Surface Plot'),
    (ax2, '3D Scatter Plot'),
    (ax3, '3D Wireframe with Contour'),
)
for panel, title in panel_titles:
    panel.set_xlabel('X')
    panel.set_ylabel('Y')
    panel.set_zlabel('Z')
    panel.set_title(title)

fig.tight_layout()
stx.io.save(fig, './plots/3d_plots.png', dpi=300)
stx.plt.show()

## 7. Time Series and Dual Axes

In [None]:
# Generate time series data: one year of daily values.
# temperature: seasonal sine cycle (amplitude 10 around 20) plus noise.
# sales: phase-shifted (pi/4) seasonal term, noise, and a linear upward trend
#        of 0.5 units per day.
dates = pd.date_range('2024-01-01', periods=365, freq='D')
temperature = 20 + 10*np.sin(2*np.pi*np.arange(365)/365) + np.random.randn(365)*2
sales = 1000 + 200*np.sin(2*np.pi*np.arange(365)/365 + np.pi/4) + \
        np.random.randn(365)*50 + np.arange(365)*0.5

# Create figure with dual axes: ax1 (left y-axis) carries temperature.
fig, ax1 = stx.plt.subplots(figsize=(12, 6))

# Plot temperature; the y-label and y tick labels share the line colour so
# each axis is visually tied to its series.
color1 = 'tab:red'
ax1.set_xlabel('Date')
ax1.set_ylabel('Temperature (°C)', color=color1)
ax1.plot(dates, temperature, color=color1, alpha=0.7, linewidth=1)
ax1.tick_params(axis='y', labelcolor=color1)
ax1.grid(True, alpha=0.3)

# Create second y-axis: twinx() shares the x-axis with ax1 and puts an
# independent y-axis on the right for sales.
ax2 = ax1.twinx()
color2 = 'tab:blue'
ax2.set_ylabel('Sales (units)', color=color2)
ax2.plot(dates, sales, color=color2, alpha=0.7, linewidth=1)
ax2.tick_params(axis='y', labelcolor=color2)

# Add title
ax1.set_title('Temperature and Sales Over Time', fontsize=14, pad=20)

# Format x-axis: limit to about 10 major ticks and rotate the date labels.
# NOTE(review): a generic MaxNLocator on a date axis can place ticks at
# arbitrary points within a day; matplotlib.dates.AutoDateLocator may give
# cleaner date ticks — consider if the labels look irregular.
ax1.xaxis.set_major_locator(plt.MaxNLocator(10))
fig.autofmt_xdate()

# Add moving averages: trailing 30-sample windows. With pandas' default
# min_periods (= window), the first window-1 values are NaN and are simply
# not drawn.
window = 30
temp_ma = pd.Series(temperature).rolling(window).mean()
sales_ma = pd.Series(sales).rolling(window).mean()

ax1.plot(dates, temp_ma, color='darkred', linewidth=2, 
         label=f'{window}-day MA (Temp)')
ax2.plot(dates, sales_ma, color='darkblue', linewidth=2, 
         label=f'{window}-day MA (Sales)')

# Add legends: each axes keeps its own legend entries, so merge both lists
# into a single legend on ax1. The raw series were plotted without labels,
# so only the two moving-average lines appear.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

fig.tight_layout()
stx.io.save(fig, './plots/time_series_dual_axes.png')
stx.plt.show()

## 8. Annotations and Text

In [None]:
# Damped sine wave sampled on [0, 10] with small Gaussian noise.
x = np.linspace(0, 10, 100)
y = np.sin(x) * np.exp(-x/10) + np.random.normal(0, 0.05, 100)

fig, ax = stx.plt.subplots(figsize=(10, 6))

ax.plot(x, y, 'b-', linewidth=2, label='Signal')

# Mark the global maximum and attach a callout offset up and to the right.
max_idx = np.argmax(y)
x_max, y_max = x[max_idx], y[max_idx]
ax.plot(x_max, y_max, 'ro', markersize=10)
ax.annotate(f'Maximum\n({x_max:.2f}, {y_max:.2f})',
            xy=(x_max, y_max), xytext=(x_max + 1, y_max + 0.2),
            arrowprops=dict(arrowstyle='->', color='red', lw=2),
            fontsize=12, ha='left',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

# Shade an illustrative region of interest.
ax.axvspan(3, 5, alpha=0.2, color='green', label='Region of Interest')

# Zero reference line.
ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)

# Summary-statistics box, anchored to the lower-left corner of the axes.
# The signal stays near or above zero over the first part of the x-range, so
# that corner is empty, and the box does not compete with the peak callout
# at the top.
textstr = '\n'.join([
    'Signal Properties:',
    f'Max value: {y[max_idx]:.3f}',
    f'Mean: {np.mean(y):.3f}',
    f'Std: {np.std(y):.3f}'
])
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
ax.text(0.05, 0.05, textstr, transform=ax.transAxes, fontsize=10,
        verticalalignment='bottom', bbox=props)

# Model equation, placed in data coordinates.
ax.text(7, -0.3, r'$y = \sin(x) \cdot e^{-x/10} + \epsilon$',
        fontsize=14, style='italic')

ax.set_xyt('x', 'y', 'Annotated Plot Example')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

# y-limits derived from the data. A fixed (-0.5, 0.8) window clipped both
# extremes of this signal (the noiseless curve peaks near +0.86 and dips to
# about -0.63). Because annotate() does not draw an annotation whose
# data-space `xy` lies outside the axes, it also hid the "Maximum" callout.
# The extra headroom fits the callout text, which sits 0.2 above the peak.
ax.set_ylim(y.min() - 0.2, y_max + 0.5)

stx.io.save(fig, './plots/annotations_example.png')
stx.plt.show()

## 9. Colormaps and 2D Density Views

In [None]:
# Create sample data for different plot types
n_points = 1000
x = np.random.randn(n_points)
y = 2*x + np.random.randn(n_points)*0.5
colors = x + y

# Create figure with different styles
fig, axes = stx.plt.subplots(2, 2, figsize=(12, 10))

# Scatter with custom colormap
scatter1 = axes[0, 0].scatter(x, y, c=colors, cmap='viridis', 
                              alpha=0.6, s=20)
axes[0, 0].set_xyt('X', 'Y', 'Viridis Colormap')
plt.colorbar(scatter1, ax=axes[0, 0])

# Hexbin plot
hb = axes[0, 1].hexbin(x, y, gridsize=30, cmap='hot', mincnt=1)
axes[0, 1].set_xyt('X', 'Y', 'Hexbin Density')
plt.colorbar(hb, ax=axes[0, 1])

# 2D histogram
hist2d, xedges, yedges = np.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
im = axes[1, 0].imshow(hist2d.T, origin='lower', extent=extent, 
                       aspect='auto', cmap='plasma')
axes[1, 0].set_xyt('X', 'Y', '2D Histogram')
plt.colorbar(im, ax=axes[1, 0])

# Contour plot with custom levels
H, xedges, yedges = np.histogram2d(x, y, bins=25)
X_edges, Y_edges = np.meshgrid(xedges[:-1], yedges[:-1])
levels = np.linspace(0, H.max(), 10)
cs = axes[1, 1].contourf(X_edges, Y_edges, H.T, levels=levels, cmap='coolwarm')
axes[1, 1].contour(X_edges, Y_edges, H.T, levels=levels, colors='k', 
                   linewidths=0.5, alpha=0.5)
axes[1, 1].set_xyt('X', 'Y', 'Contour Plot')
plt.colorbar(cs, ax=axes[1, 1])

fig.suptitle('Different Visualization Styles for 2D Data', fontsize=14)
fig.tight_layout()
stx.io.save(fig, './plots/custom_styles.png', dpi=300)
stx.plt.show()

## 10. Publication-Ready Figure Template

In [None]:
# Figure size for a journal figure. Typical widths: single column ~3.5 in,
# double column ~7 in. Height follows the golden ratio.
fig_width = 7  # inches
golden_ratio = 1.618
fig_height = fig_width / golden_ratio

# Publication style, scoped to this block. plt.style.context() applies the
# style only inside the `with` body and restores the previous rcParams on exit,
# even if a call fails part-way. The earlier pattern, plt.style.use(...)
# followed by plt.style.use('default'), leaked the style on error. On success
# it reset every rcParam to Matplotlib's factory defaults, discarding whatever
# defaults were active before this cell.
with plt.style.context('seaborn-v0_8-paper'):
    fig, (ax1, ax2) = stx.plt.subplots(1, 2, figsize=(fig_width, fig_height))

    # Signal 1: damped cosine, observed with Gaussian noise (sd 0.05).
    x = np.linspace(0, 5, 50)
    y1_mean = np.exp(-x/2) * np.cos(2*np.pi*x)
    y1_std = 0.05 * np.ones_like(x)
    y1 = y1_mean + np.random.normal(0, y1_std)

    # Signal 2: damped sine. Only its noiseless curve is plotted, so no noisy
    # samples are drawn for it.
    y2_mean = np.exp(-x/3) * np.sin(2*np.pi*x)

    # Panel A: observations vs. theory, with a +/- 2 sd band around the theory.
    ax1.plot(x, y1, 'o', markersize=4, color='C0', alpha=0.6, label='Observed')
    ax1.plot(x, y1_mean, '-', color='C0', linewidth=2, label='Theory')
    ax1.fill_between(x, y1_mean - 2*y1_std, y1_mean + 2*y1_std,
                     alpha=0.2, color='C0')
    ax1.set_xlabel('Time (ms)', fontsize=10)
    ax1.set_ylabel('Amplitude (μV)', fontsize=10)
    ax1.set_title('(A) Damped Oscillation', fontsize=10, loc='left')
    ax1.legend(frameon=False, fontsize=8)
    ax1.grid(True, alpha=0.3, linestyle=':')

    # Panel B: the two noiseless signals compared.
    ax2.plot(x, y1_mean, '-', color='C0', linewidth=2, label='Signal 1')
    ax2.plot(x, y2_mean, '-', color='C1', linewidth=2, label='Signal 2')
    ax2.set_xlabel('Time (ms)', fontsize=10)
    ax2.set_ylabel('Amplitude (μV)', fontsize=10)
    ax2.set_title('(B) Signal Comparison', fontsize=10, loc='left')
    ax2.legend(frameon=False, fontsize=8)
    ax2.grid(True, alpha=0.3, linestyle=':')

    fig.tight_layout()

    # Raster (PNG) plus vector (PDF, SVG) copies for submission.
    for fmt in ['png', 'pdf', 'svg']:
        stx.io.save(fig, f'./plots/publication_figure.{fmt}', dpi=300)

    stx.plt.show()

## Best Practices for Scientific Plotting

### 1. **Color Choices**
- Use colorblind-friendly palettes
- Avoid red-green combinations
- Test plots in grayscale

### 2. **Font Sizes**
- Use at least 8 pt text (at final printed size) for publication
- Scale with figure size
- Consistent throughout figure

### 3. **Data Presentation**
- Show individual data points when possible
- Include error bars or confidence intervals
- Avoid chartjunk

### 4. **File Formats**
- **PNG**: For web/presentations (300 dpi)
- **PDF/SVG**: For publications (vector)
- **TIFF**: Some journals require this

### 5. **Reproducibility**
- Save plotting code with data
- Fix and record random seeds (e.g., `np.random.seed(42)`)
- Document all parameters

In [None]:
# Report what the notebook wrote to ./plots. Nothing is deleted; the directory
# is kept for reference.
# NOTE(review): shutil is imported but not used in this cell.
import shutil
from pathlib import Path

plots_dir = Path('./plots')
if plots_dir.exists():
    report = ["Plots saved in ./plots/", "Files created:"]
    report += [f"  - {entry.name}" for entry in sorted(plots_dir.glob('*'))]
    print("\n".join(report))