In [None]:
# Scratch check that this cell executes; prints a fixed marker string.
print("test", end="\n")

In [None]:
import matplotlib.pyplot as plt

# 1. Basic visualization of neural activity over time and channels
def plot_neural_data(neural_data, title="Neural Data", n_timepoints=1000):
    """
    Show a heatmap of the leading time bins of a neural recording.

    Parameters
    ----------
    neural_data : array-like, shape (time, channels)
        Recording to display; only the first ``n_timepoints`` rows are drawn.
    title : str
        Title placed above the heatmap.
    n_timepoints : int
        Number of leading time bins to include.
    """
    window = neural_data[:n_timepoints]

    fig = plt.figure(figsize=(12, 6))
    ax = plt.gca()

    # Transpose so channels run down the y-axis and time along the x-axis.
    image = ax.imshow(window.T, cmap='viridis', aspect='auto')
    fig.colorbar(image, ax=ax, label='Activity')

    ax.set_title(title)
    ax.set_xlabel('Time (bins)')
    ax.set_ylabel('Channel')
    plt.show()

# 2. Plot neural data with corresponding labels
def plot_neural_with_labels(neural_data, labels, n_timepoints=1000):
    """
    Plot a neural-activity heatmap above the matching label trace.

    Both panels share one time axis, so a label transition sits directly
    beneath the corresponding heatmap columns, and zooming/panning one
    panel moves the other.

    Parameters
    ----------
    neural_data : array-like, shape (time, channels)
        Recording; the first ``n_timepoints`` rows are drawn as a heatmap.
    labels : array-like, shape (time,)
        Per-bin labels, drawn as a line. The y-limits of (-0.1, 1.1)
        assume binary silence/speech labels — confirm for other label sets.
    n_timepoints : int
        Number of leading time bins to plot from both inputs.
    """
    # sharex=True: imshow pins its x-range to [-0.5, n-0.5] while a plain
    # line plot autoscales with margins; sharing keeps the bins aligned.
    fig, (ax1, ax2) = plt.subplots(
        2, 1,
        figsize=(12, 8),
        sharex=True,
        gridspec_kw={'height_ratios': [3, 1]},
    )

    # Neural data: transpose so channels are rows and time runs along x.
    im = ax1.imshow(neural_data[:n_timepoints].T, aspect='auto', cmap='viridis')
    ax1.set_title('Neural Activity')
    ax1.set_ylabel('Channel')
    plt.colorbar(im, ax=ax1, label='Activity')

    # Labels
    ax2.plot(labels[:n_timepoints], 'r', label='Silence/Speech')
    ax2.set_xlabel('Time (bins)')
    ax2.set_ylabel('Label')
    ax2.set_ylim(-0.1, 1.1)
    ax2.legend()

    plt.tight_layout()
    plt.show()

# 3. Plot average activity across channels
def plot_channel_average(neural_data, labels, n_timepoints=1000):
    """
    Plot channel-averaged neural activity against the labels over time.

    The mean activity (blue) uses the left y-axis. The labels (red) use a
    twin right y-axis that shares the same time axis. A single legend lists
    both lines.

    Parameters
    ----------
    neural_data : array-like, shape (time, channels)
        Recording; averaged across channels (axis 1) in each time bin.
    labels : array-like, shape (time,)
        Per-bin labels. The (-0.1, 1.1) y-limits match the label panel in
        ``plot_neural_with_labels`` and assume binary labels — confirm.
    n_timepoints : int
        Number of leading time bins to plot from both inputs.
    """
    mean_activity = neural_data[:n_timepoints].mean(axis=1)

    fig, ax1 = plt.subplots(figsize=(12, 4))

    activity_line, = ax1.plot(mean_activity, 'b', label='Mean Activity')
    ax1.set_xlabel('Time (bins)')
    ax1.set_ylabel('Mean Activity', color='b')
    ax1.tick_params(axis='y', labelcolor='b')

    ax2 = ax1.twinx()
    label_line, = ax2.plot(labels[:n_timepoints], 'r', label='Labels', alpha=0.5)
    ax2.set_ylabel('Label', color='r')
    ax2.tick_params(axis='y', labelcolor='r')
    ax2.set_ylim(-0.1, 1.1)

    # The two lines live on different axes, so neither axes' legend() would
    # see both; pass the handles explicitly. Draw it on ax2, which is
    # rendered on top, so the legend is not hidden behind the label trace.
    ax2.legend(handles=[activity_line, label_line], loc='upper right')

    # Title the primary axes explicitly (plt.title would target the twin).
    ax1.set_title('Mean Neural Activity vs Labels')
    plt.show()

    
# numpy is not imported anywhere else in this notebook; without this line
# the np.load calls below raise NameError.
from pathlib import Path

import numpy as np

# Directory holding the preprocessed arrays (cluster-specific path).
DATA_DIR = Path("/home/groups/henderj/rzwang/processed_data")

# Assumes neural_data is (time, channels) and labels is (time,), aligned
# bin-for-bin, as the plotting helpers expect — TODO confirm.
neural_data = np.load(DATA_DIR / "neural_data.npy")
labels = np.load(DATA_DIR / "labels.npy")

# Basic neural data visualization
plot_neural_data(neural_data)

# Neural data with labels
plot_neural_with_labels(neural_data, labels)

# Average activity
plot_channel_average(neural_data, labels)