In [1]:
import torch
import copy
import random

In [2]:
# PRIVATE

import submitit

# Every executor below gets the same 24 h wall-clock limit (timeout_min is in minutes).
_ONE_DAY_MIN = 60 * 24

# GPU partition (RTX 2080 nodes, "ml" pool).
ex_ml = submitit.get_executor(slurm_partition='ml_gpu-rtx2080', timeout_min=_ONE_DAY_MIN)

# CPU-only partition: 8 CPUs per task, and `gpu:0` explicitly requests no GPUs.
ex_bosch_cpu = submitit.get_executor(
    slurm_partition='bosch_cpu-cascadelake',
    cpus_per_task=8,
    timeout_min=_ONE_DAY_MIN,
    slurm_gres='gpu:0',
)

# GPU partition (RTX 2080 nodes, "bosch" pool).
ex_bosch = submitit.get_executor(slurm_partition='bosch_gpu-rtx2080', timeout_min=_ONE_DAY_MIN)

# GPU partition (RTX 2080 nodes, "alldlc" pool).
ex = submitit.get_executor(slurm_partition='alldlc_gpu-rtx2080', timeout_min=_ONE_DAY_MIN)

In [3]:
def create_mlp(input_size, hidden_sizes, activation="relu", num_outputs=1, weight_multiplier=1.0):
    """Build a fully connected network as a ``torch.nn.Sequential``.

    Args:
        input_size: number of input features.
        hidden_sizes: widths of the hidden layers, in order. Any iterable of ints
            is accepted (list, tuple, ...); an empty one yields a single Linear layer.
        activation: ``"relu"`` or ``"tanh"``. Inserted after every Linear layer
            except the last, so the network outputs raw logits.
        num_outputs: width of the final Linear layer.
        weight_multiplier: factor applied in place to each layer's default
            initialisation, to weights *and* biases (setup_experiment uses it to
            scale the teacher).

    Returns:
        torch.nn.Sequential of alternating Linear / activation modules.

    Raises:
        ValueError: if ``activation`` is not supported. The check runs up front,
            so an invalid name is rejected even when there are no hidden layers.
    """
    activation_modules = {"relu": torch.nn.ReLU, "tanh": torch.nn.Tanh}
    if activation not in activation_modules:
        raise ValueError("Activation must be either 'relu' or 'tanh'")
    activation_cls = activation_modules[activation]

    layer_sizes = [input_size, *hidden_sizes, num_outputs]
    num_linear = len(layer_sizes) - 1

    layers = []
    for i, (fan_in, fan_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        linear_layer = torch.nn.Linear(fan_in, fan_out)

        # Rescale the default init in place; no_grad because these are leaf Parameters.
        with torch.no_grad():
            linear_layer.weight.mul_(weight_multiplier)
            linear_layer.bias.mul_(weight_multiplier)

        layers.append(linear_layer)

        if i < num_linear - 1:  # no activation after the output layer
            layers.append(activation_cls())

    return torch.nn.Sequential(*layers)

In [4]:
from time import time
import torch.optim.lr_scheduler as lr_scheduler

def train_student(teacher_mlp, student_mlp, num_features, test_examples, num_steps=1000, batch_size=256, learning_rate=0.001,
                  use_input_optimization=False, gen_learning_rate=0.001, old_generator_prob=0.5, apply_scheduler_to_gen=False, return_saved_generators=False,
                  device=None, log_interval=100, snapshot_interval=100,
                 ):
    """Distil ``teacher_mlp`` into ``student_mlp`` by minimising KL(teacher || student).

    Inputs are drawn from [-1, 1]^num_features. In standard mode the whole batch
    is uniform noise. With ``use_input_optimization`` half of each batch comes from
    a learned per-row Gaussian ``InputGenerator`` that is trained by gradient
    *ascent* on the same KL loss. The other half is either fresh uniform noise or
    the means of a randomly chosen earlier generator snapshot.

    Both modules are moved to ``device`` (``Module.to`` is in place) and the
    student is trained in place.

    Args:
        teacher_mlp, student_mlp: modules mapping (batch, num_features) inputs to
            logits; log_softmax is taken over dim 1.
        num_features: input dimensionality.
        test_examples: held-out inputs. The final KL on them is printed, not returned.
        num_steps: optimisation steps; also T_max of the cosine LR schedule(s).
        batch_size: total batch size. With input optimisation each half has
            ``batch_size // 2`` rows.
        learning_rate: Adam learning rate of the student.
        use_input_optimization: enable the adversarial input generator.
        gen_learning_rate: Adam learning rate of the generator.
        old_generator_prob: probability of replaying saved generator means instead
            of fresh uniform noise for the non-generated half (never on step 0).
        apply_scheduler_to_gen: also cosine-anneal the generator learning rate.
        return_saved_generators: return the generator snapshots as second output.
        device: torch device; defaults to CUDA when available, else CPU.
        log_interval: steps per logged average loss (default 100).
        snapshot_interval: steps between generator snapshots (default 100).

    Returns:
        ``(agg_losses, saved_generators_or_None)``. ``agg_losses`` holds the mean
        training loss of each logging window. A trailing partial window is
        included, so the list is non-empty whenever ``num_steps > 0``. These are
        training losses: in adaptive mode half of each batch is adversarial, so
        they are not directly comparable with standard-mode losses.

    Raises:
        ValueError: if ``log_interval`` or ``snapshot_interval`` is < 1.
    """
    if log_interval < 1 or snapshot_interval < 1:
        raise ValueError("log_interval and snapshot_interval must be positive")

    # Use CUDA if available
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher_mlp = teacher_mlp.to(device)
    student_mlp = student_mlp.to(device)

    optimizer = torch.optim.Adam(student_mlp.parameters(), lr=learning_rate)
    # Cosine annealing of the student learning rate over the whole run.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

    input_generator = None
    gen_optimizer = None
    gen_scheduler = None
    if use_input_optimization:
        class InputGenerator(torch.nn.Module):
            """Learnable per-row diagonal Gaussians; forward() draws one reparameterised sample per row."""

            def __init__(self, batch_size, input_dim):
                super().__init__()
                self.means = torch.nn.Parameter(torch.rand(batch_size // 2, input_dim, device=device) * 2 - 1)
                self.log_stds = torch.nn.Parameter(torch.zeros(batch_size // 2, input_dim, device=device))

            def forward(self):
                stds = torch.exp(self.log_stds)
                epsilon = torch.randn_like(self.means)
                return self.means + stds * epsilon

        input_generator = InputGenerator(batch_size, num_features).to(device)
        gen_optimizer = torch.optim.Adam(input_generator.parameters(), lr=gen_learning_rate)
        if apply_scheduler_to_gen:
            gen_scheduler = lr_scheduler.CosineAnnealingLR(gen_optimizer, T_max=num_steps)

    def kl_divergence(log_p, log_q):
        """Batch mean of KL(p || q) for row-wise log-probabilities."""
        return torch.sum(log_p.exp() * (log_p - log_q), dim=1).mean()

    saved_generators = []
    agg_losses = []
    total_time = 0.0
    window_loss = 0.0  # summed loss of the current logging window
    window_steps = 0   # number of steps in the current logging window
    for step in range(num_steps):
        start_time = time()
        if use_input_optimization:
            gen_inputs = torch.clamp(input_generator(), min=-1, max=1)

            # Non-generated half: fresh uniform noise, or (with prob. old_generator_prob,
            # never on step 0 when no snapshot exists yet) a saved snapshot's means.
            if step == 0 or random.random() > old_generator_prob:
                random_inputs = torch.rand(batch_size // 2, num_features, device=device) * 2 - 1
            else:
                random_inputs = saved_generators[random.randint(0, len(saved_generators) - 1)]['means'].to(device)

            inputs = torch.cat([gen_inputs, random_inputs], dim=0)
        else:
            inputs = torch.rand(batch_size, num_features, device=device) * 2 - 1

        teacher_log_probs = torch.log_softmax(teacher_mlp(inputs), dim=1)
        student_log_probs = torch.log_softmax(student_mlp(inputs), dim=1)
        loss = kl_divergence(teacher_log_probs, student_log_probs)

        window_loss += loss.item()
        window_steps += 1

        # A single backward pass fills gradients of both the student and the generator.
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        if use_input_optimization:
            # Negate the generator's gradients so Adam performs gradient *ascent* on the KL.
            for param in input_generator.parameters():
                if param.grad is not None:
                    param.grad = -param.grad
            gen_optimizer.step()
            gen_optimizer.zero_grad()
            if gen_scheduler is not None:
                gen_scheduler.step()
            with torch.no_grad():
                input_generator.means.clamp_(-1, 1)

        step_time = time() - start_time
        total_time += step_time

        # Log every `log_interval` steps, plus once for a trailing partial window.
        if (step + 1) % log_interval == 0 or step + 1 == num_steps:
            avg_loss = window_loss / window_steps
            avg_time_per_step = total_time / (step + 1)
            print(f"Step [{step+1}/{num_steps}], Average Loss: {avg_loss:.4f}, Current Loss: {loss.item():.4f}, Avg Time/Step: {avg_time_per_step:.4f}s, LR {optimizer.param_groups[0]['lr']}")
            agg_losses.append(avg_loss)
            window_loss = 0.0
            window_steps = 0

        if use_input_optimization and step % snapshot_interval == 0:
            saved_generators.append({name: param.detach().clone().cpu() for name, param in input_generator.named_parameters()})

    print("Training completed!")

    # Held-out evaluation: printed only, not returned.
    with torch.no_grad():
        test_inputs = test_examples.to(device)
        student_log_probs = torch.log_softmax(student_mlp(test_inputs), dim=1)
        teacher_log_probs = torch.log_softmax(teacher_mlp(test_inputs), dim=1)
        kl_div = kl_divergence(teacher_log_probs, student_log_probs)
    print(f"KL Divergence: {kl_div.item():.4f}")

    if return_saved_generators and use_input_optimization:
        return agg_losses, saved_generators
    return agg_losses, None

In [None]:
def setup_experiment(
    num_features=10,
    teacher_hidden_sizes=[32, 32],
    student_hidden_sizes=[64, 64, 32],
    teacher_activation="tanh",
    student_activation="relu",
    teacher_weight_multiplier=8.0,
    num_outputs=2,
    num_test_examples=100,
    grid_size=3
):
    """Create a frozen random teacher, an untrained student and uniform test inputs.

    The teacher's initial weights are scaled by ``teacher_weight_multiplier`` and
    its parameters are frozen. As a sanity check, the teacher is evaluated on a
    regular lattice of ``grid_size ** num_features`` points over
    [-1, 1]^num_features (this grows exponentially with ``num_features``). The
    minimum and maximum probability it assigns to class 0 are then printed.

    Returns:
        ``(teacher_mlp, student_mlp, test_examples)`` where ``test_examples`` is a
        ``(num_test_examples, num_features)`` tensor, uniform in [-1, 1].
    """
    # Construction order (teacher, student, test inputs) fixes the RNG stream.
    teacher = create_mlp(
        input_size=num_features,
        hidden_sizes=teacher_hidden_sizes,
        activation=teacher_activation,
        num_outputs=num_outputs,
        weight_multiplier=teacher_weight_multiplier,
    )
    teacher.requires_grad_(False)  # the teacher is a fixed target

    student = create_mlp(
        input_size=num_features,
        hidden_sizes=student_hidden_sizes,
        activation=student_activation,
        num_outputs=num_outputs,
    )

    test_examples = 2 * torch.rand(num_test_examples, num_features) - 1

    # (grid_size ** num_features, num_features) lattice of evaluation points.
    axis = torch.linspace(-1, 1, grid_size)
    mesh = torch.meshgrid(*([axis] * num_features), indexing='ij')
    lattice = torch.stack(mesh).reshape(num_features, -1).t()

    with torch.no_grad():
        class_0_probs = torch.softmax(teacher(lattice), dim=1)[:, 0]

    print(f"Minimum probability for class 0: {class_0_probs.min().item():.4f}")
    print(f"Maximum probability for class 0: {class_0_probs.max().item():.4f}")

    return teacher, student, test_examples

# Default experiment in 2-D, which the 2-input contour plots further down rely on.
num_features = 2
teacher_mlp, student_mlp, test_examples = setup_experiment(num_features=num_features)

# Any setup_experiment argument can be overridden, for example:
# teacher_mlp, student_mlp, test_examples = setup_experiment(
#     num_features=12,
#     teacher_hidden_sizes=[64, 64],
#     student_hidden_sizes=[128, 128, 64],
#     teacher_weight_multiplier=10.0
# )

def get_student_mlp_copy():
    """Return an independent deep copy of the shared initial ``student_mlp``.

    Training a copy leaves the original untouched, so every run starts from the
    same initial weights.
    """
    template = student_mlp
    return copy.deepcopy(template)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_mlp_predictions(mlp1, mlp2=None, resolution=100, class_index=0):
    """Contour-plot a 2-input, 2-class model's softmax output over [-1, 1]^2.

    Args:
        mlp1: model mapping an (N, 2) float tensor to (N, 2) logits (the
            two-class output is asserted).
        mlp2: optional second model with the same contract. When given, the plot
            shows ``P_mlp2 - P_mlp1`` for the selected class on a diverging colour
            map centred at 0, instead of ``P_mlp1``.
        resolution: number of grid points per axis.
        class_index: softmax column to visualise (default 0, the column earlier
            versions of this function plotted).

    Each model is evaluated on the device its parameters live on, because
    train_student moves models to CUDA in place when available. Results are
    copied back to the CPU for plotting.
    """
    axis_values = np.linspace(-1, 1, resolution)
    X, Y = np.meshgrid(axis_values, axis_values)
    # (resolution**2, 2) grid of (x, y) inputs.
    grid_inputs = torch.as_tensor(np.column_stack([X.ravel(), Y.ravel()]), dtype=torch.float32)

    def class_probs(model):
        """Probability of `class_index` at every grid point, as a NumPy array."""
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")
        with torch.no_grad():
            logits = model(grid_inputs.to(model_device))
        assert logits.shape[1] == 2
        return torch.softmax(logits, dim=1)[:, class_index].cpu().numpy()

    probs1 = class_probs(mlp1)

    fig, ax = plt.subplots(figsize=(10, 8))
    if mlp2 is None:
        Z = probs1.reshape(X.shape)
        contour = ax.contourf(X, Y, Z, levels=20, cmap='viridis')
        label = f'Probability of Class {class_index}'
        title = f'MLP Predictions - Probability of Class {class_index}'
    else:
        Z = (class_probs(mlp2) - probs1).reshape(X.shape)
        # Symmetric levels so that 0 (agreement) sits at the centre of the colour map.
        limit = max(float(np.abs(Z).max()), 1e-6)
        contour = ax.contourf(X, Y, Z, levels=np.linspace(-limit, limit, 21), cmap='coolwarm')
        label = f'Difference in Probability of Class {class_index} (mlp2 - mlp1)'
        title = f'MLP Prediction Difference - Class {class_index}'

    fig.colorbar(contour, ax=ax, label=label)
    ax.set(title=title, xlabel='Input 1', ylabel='Input 2')
    plt.show()

# Teacher's decision surface over [-1, 1]^2 (the teacher was built with num_features=2).
visualize_mlp_predictions(teacher_mlp)


In [None]:
# Baseline: train a fresh copy of the shared initial student on uniform inputs only.
# Keep the logged losses instead of echoing the returned tuple as cell output.
student_mlp_standard = get_student_mlp_copy()
losses_standard, _ = train_student(teacher_mlp, student_mlp_standard, num_features, test_examples, num_steps=10000)

In [None]:
# Adaptive student: half of every batch comes from an adversarially trained input
# generator. With old_generator_prob=1. the other half replays the means of a
# saved generator snapshot on every step except the first.
student_mlp_adaptive = get_student_mlp_copy()
losses, saved_generators = train_student(
    teacher_mlp,
    student_mlp_adaptive,
    num_features,
    test_examples,
    num_steps=10000,
    use_input_optimization=True,
    gen_learning_rate=0.01,
    return_saved_generators=True,
    old_generator_prob=1.,
)

In [None]:
# First figure: the teacher's decision surface.
# Second figure: teacher minus adaptive-student probability of the class in softmax
# column 0 (visualize_mlp_predictions plots mlp2 - mlp1 when given two models).
visualize_mlp_predictions(teacher_mlp)
visualize_mlp_predictions(student_mlp_adaptive, teacher_mlp)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Scatter the saved generator means (first two features) across training snapshots.
# Each snapshot contributes one point per mean; colour identifies the mean.
MAX_MEANS_TO_PLOT = 100
first_means = np.array([s['means'][:MAX_MEANS_TO_PLOT, :2].detach().numpy() for s in saved_generators])
num_means = first_means.shape[1]  # first_means: (snapshots, means, 2 features)

print(first_means.shape)

fig, ax = plt.subplots(figsize=(10, 8))

# One colour per mean index.
cmap = plt.get_cmap('viridis')
colors = [cmap(v) for v in np.linspace(0, 1, num_means)]

for i in range(num_means):
    # Use `color=`, not `c=`: a single RGBA tuple passed as `c` is value-mapped
    # through the colormap when there happen to be exactly 4 snapshots.
    ax.scatter(first_means[:, i, 0], first_means[:, i, 1], color=colors[i], s=10, alpha=.5)

ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title(f'Paths of First {num_means} Generator Means During Training')

# Colorbar range matches the mean indices used for the colours above.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=max(num_means - 1, 1)))
sm.set_array([])
fig.colorbar(sm, ax=ax, label='Mean Index')

# Generator means are clamped to [-1, 1] during training.
ax.set_aspect('equal')
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)

ax.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()


In [None]:
import seaborn as sns
# Generator log standard deviations (feature 0, first 100 rows) across saved
# snapshots. These are the *log* stds, exactly as stored in the snapshots.
first_log_stds = np.array([s['log_stds'][:100, 0].detach().numpy() for s in saved_generators])

# Wide-form array (snapshots x rows): seaborn draws one line per row. The legend
# is disabled because it would list every row.
ax = sns.lineplot(data=first_log_stds, legend=False)
ax.set(xlabel='Snapshot index', ylabel='log std (feature 0)', title='Generator log-stds during training');

In [None]:
def run_experiment(
    num_repeats=10,
    num_features=10,

    # Standard training parameters
    num_steps=10000,
    learning_rate=0.001,

    # Adaptive training parameters
    adaptive_gen_learning_rate=0.01,
    adaptive_old_generator_prob=0.5,
    apply_scheduler_to_gen=False,

    # Reproducibility
    seed=None,

    # Experiment setup parameters
    **setup_kwargs,
):
    """Compare standard vs. adaptive (adversarial-input) distillation over repeats.

    For each repeat, a fresh teacher/student pair is built with
    ``setup_experiment(num_features=num_features, **setup_kwargs)``. The same
    initial student is deep-copied twice: one copy is trained in standard mode
    and the other with ``use_input_optimization=True``.

    Args:
        num_repeats: number of independent teacher/student draws.
        num_features: input dimensionality (forwarded to setup_experiment).
        num_steps, learning_rate: student training settings, used in both modes.
        adaptive_gen_learning_rate, adaptive_old_generator_prob, apply_scheduler_to_gen:
            generator settings for the adaptive run. They are forwarded to
            train_student as gen_learning_rate, old_generator_prob and
            apply_scheduler_to_gen.
        seed: optional base seed. When given, ``torch`` and ``random`` are seeded
            with ``seed + repeat_index`` before each repeat, which makes the
            teacher, student, test set and training noise reproducible. ``None``
            (the default) leaves the RNGs untouched.
        **setup_kwargs: remaining setup_experiment arguments, e.g.
            teacher_hidden_sizes, student_hidden_sizes, teacher_activation,
            student_activation, teacher_weight_multiplier, num_outputs,
            num_test_examples, grid_size.

    Returns:
        A list with one ``(losses_default, losses_adaptive)`` tuple per repeat.
        Each element is the list of per-logging-window *training* losses
        returned by train_student.
    """
    all_losses = []
    for repeat_ind in range(num_repeats):
        if seed is not None:
            torch.manual_seed(seed + repeat_ind)
            random.seed(seed + repeat_ind)

        teacher_mlp, student_mlp, test_examples = setup_experiment(num_features=num_features, **setup_kwargs)

        # Both students start from identical copies of the same initial weights.
        print("Training standard student:")
        student_mlp_standard = copy.deepcopy(student_mlp)
        losses_default, _ = train_student(teacher_mlp, student_mlp_standard, num_features, test_examples,
                                          num_steps=num_steps,
                                          learning_rate=learning_rate)

        print("\nTraining adaptive student:")
        student_mlp_adaptive = copy.deepcopy(student_mlp)
        losses_adaptive, _ = train_student(teacher_mlp, student_mlp_adaptive, num_features, test_examples,
                                           num_steps=num_steps,
                                           learning_rate=learning_rate,
                                           use_input_optimization=True,
                                           gen_learning_rate=adaptive_gen_learning_rate,
                                           old_generator_prob=adaptive_old_generator_prob,
                                           apply_scheduler_to_gen=apply_scheduler_to_gen,
                                           return_saved_generators=False)

        all_losses.append((losses_default, losses_adaptive))

    return all_losses
    
# Example usage: a quick smoke run of the whole pipeline in 2-D.
# num_steps=100 covers one full logging window of train_student (100 steps by
# default), so every returned loss list has at least one entry.
exp_result = run_experiment(
    num_features=2,
    teacher_hidden_sizes=[32, 32],
    student_hidden_sizes=[64, 64, 32],
    teacher_activation="tanh",
    student_activation="relu",
    teacher_weight_multiplier=8.0,
    num_outputs=2,
    num_steps=100,
    num_test_examples=1000,
    grid_size=3,
    learning_rate=0.001,
    adaptive_gen_learning_rate=0.0001,
    adaptive_old_generator_prob=0.5,
)

In [209]:
from pfns.utils import product_dict

In [None]:
# Hyper-parameter sweep for run_experiment. product_dict presumably yields one
# config dict per combination (Cartesian product) — confirm in pfns.utils.
adaptive_learning1_grid = {
    'num_features': [2, 10],
    'teacher_activation': ['tanh', 'relu'],
    'teacher_weight_multiplier': [1., 2., 4., 8.],
    'num_steps': [1_000, 10_000, 100_000],
    'adaptive_gen_learning_rate': [.1, .01, .001, .0001],
    'learning_rate': [.003, .001],
    'adaptive_old_generator_prob': [0., .5, 1.],
    'apply_scheduler_to_gen': [False, True],
}

# NOTE(review): submitted via ex_bosch_cpu, but later fetched via ex.get_group;
# confirm that groups are looked up independently of the executor.
adaptive_learning1 = ex_bosch_cpu.submit_group(
    'adaptive_learning1',
    run_experiment,
    list(product_dict(adaptive_learning1_grid)),
)

In [None]:
# Fetch the submitted job group by name and display it.
# NOTE(review): the group was submitted through ex_bosch_cpu; confirm ex.get_group finds it.
adaptive_learning1 = ex.get_group('adaptive_learning1')
adaptive_learning1

In [11]:
import pandas as pd

# One row per (finished job, training type). The value is the mean over the
# job's repeats of the *last logged training loss* returned by train_student.
# NOTE: for the adaptive student this training loss is measured partly on
# adversarial inputs, so default vs. adaptive values are not strictly
# comparable; the held-out test KL is only printed by train_student.
results = []
for job in adaptive_learning1:
    if not job.done():
        continue
    repeats = job.result()  # list of (default_losses, adaptive_losses) tuples, one per repeat
    for i, result_type in enumerate(['default', 'adaptive']):
        # Runs shorter than the logging interval log nothing; skip them instead of IndexError.
        final_losses = [seed_result[i][-1] for seed_result in repeats if seed_result[i]]
        if not final_losses:
            continue
        results.append({
            'loss': sum(final_losses) / len(final_losses),
            'type': result_type,
            **job.config,
        })

results_df = pd.DataFrame(results)

In [None]:
# Aggregated results table: one row per finished job and training type ('default' / 'adaptive').
results_df

In [None]:
import seaborn as sns

fig, ax = plt.subplots(figsize=(10, 10))

# Fix every swept hyper-parameter except num_steps and compare default vs. adaptive.
filtered_df = results_df[
    (results_df['num_features'] == 10)
    & (results_df.adaptive_old_generator_prob == 1.)
    & (results_df.adaptive_gen_learning_rate == .01)
    & (results_df.learning_rate == .001)
    & (results_df.teacher_activation == 'relu')
    & (results_df.teacher_weight_multiplier == 8.)
    & (results_df.apply_scheduler_to_gen)
]

sns.lineplot(data=filtered_df, x='num_steps', y='loss', style='type', hue='adaptive_gen_learning_rate', ax=ax)
ax.set_xscale('log')
ax.set(xlabel='num_steps', ylabel='Last logged training loss (mean over repeats)',
       title='Default vs. adaptive training');

In [None]:
# NOTE(review): scratch cell checking torch.where's single-argument form (tuple of
# index tensors); it is unrelated to the experiment above — consider removing.
torch.where(torch.tensor([[0,1],[0,1]]) == 0)