# Dispersion models

In [None]:
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np

from torchlensmaker.materials import *

def plot_dispersion_models(models, wmin=400, wmax=700, labels=None, num_points=1000):
    """
    Plot the refractive index of multiple dispersion models from wmin to wmax (in nm).

    Args:
        models: iterable of DispersionModel objects; each must provide a
            ``refractive_index(W)`` method accepting a 1D tensor of wavelengths
        wmin: lower bound of the plotted wavelength range (nm)
        wmax: upper bound of the plotted wavelength range (nm)
        labels: optional iterable of legend labels, one per model. Defaults to
            each model's class name.
        num_points: number of wavelength samples between wmin and wmax
    """
    # Materialize once: a one-shot iterator (e.g. a generator) would otherwise
    # be exhausted by the default-label comprehension before anything is plotted.
    models = list(models)

    W = torch.linspace(wmin, wmax, num_points)

    fig, ax = plt.subplots(figsize=(12, 8))

    if labels is None:
        # Use the readable class name rather than the class object itself,
        # whose repr ("<class '...'>") would clutter the legend.
        labels = [type(m).__name__ for m in models]

    # Evaluate without autograd tracking so the outputs are plain tensors that
    # matplotlib can convert to arrays, even if a model holds trainable
    # parameters (models may be nn.Modules — NOTE(review): confirm).
    with torch.no_grad():
        for label, model in zip(labels, models):
            N = model.refractive_index(W)
            ax.plot(W, N, label=label)

    ax.legend()
    ax.set_ylabel("Index of refraction")
    ax.set_xlabel("Wavelength (nm)")


# Plot every default material model, labelled with its registry name.
plot_dispersion_models(
    list(default_material_models_dict.values()),
    labels=list(default_material_models_dict.keys()),
)