# Firework Simulator
Click to launch a firework! The longer you hold the mouse button, the larger the firework will be (the effect stops growing after 1.5 s). Press Esc to return to the menu.

Four types of fireworks are currently enabled — see which one you like the most. :)

In [None]:
import pygame
import pygame_menu
import random
import numpy as np
from pygame.math import Vector2
from queue import Queue
from numpy.random import choice

# Window geometry: 1700 px wide with a 16:10 aspect ratio.
SCREEN_WIDTH = 1700
SCREEN_HEIGHT = SCREEN_WIDTH * 10 // 16
SCREEN_SIZE = (SCREEN_WIDTH, SCREEN_HEIGHT)
FPS = 60

# Placeholders: the real tables reference the firework classes, so they are
# assigned after those classes are defined.
EXPLOSION_TYPES = EXPLOSION_WEIGHTS = RECURSIVE_EXPLOSION_WEIGHTS = None

# pygame handles, populated by main().
screen = clock = menu = None

pygame.init()

def lerp(a, b, t, type=float):
    """Linearly interpolate from ``a`` to ``b`` by the factor ``t``.

    ``t`` is clamped to [0, 1] first, so out-of-range factors saturate at the
    endpoints. The result is converted with ``type`` (``float`` by default;
    ``int`` truncates).
    """
    weight = max(0, min(t, 1))
    span = b - a
    return type(a + span * weight)

def color_lerp(color1, color2, t):
    """Blend two colors channel by channel with ``lerp``.

    Only the first three channels (RGB) are interpolated; any further channel
    is ignored. Each resulting channel is a float, and ``t`` is clamped to
    [0, 1] by ``lerp``.
    """
    return tuple(lerp(color1[channel], color2[channel], t) for channel in range(3))

def _gradient_rgba(color):
    """Return ``color`` as an (r, g, b, a) tuple, defaulting a missing alpha to 255."""
    if len(color) == 3:
        return (*color, 255)
    if len(color) == 4:
        return tuple(color)
    raise ValueError(f"expected an RGB or RGBA color, got {color!r}")


def vertical_gradient(size, startcolor, endcolor):
    """Return a surface of ``size`` filled with a top-to-bottom linear gradient.

    The gradient is rendered into a 1-pixel-wide column (one ``set_at`` per
    row) and then stretched horizontally with ``pygame.transform.scale``.

    Args:
        size: ``(width, height)`` of the returned surface.
        startcolor: RGB or RGBA color of the top row.
        endcolor: RGB or RGBA color the gradient moves toward at the bottom.
            Each color is normalised separately, so RGB and RGBA colors may be
            mixed; a missing alpha channel defaults to 255.

    Raises:
        ValueError: if ``size[1]`` is not positive, or a color does not have
            exactly 3 or 4 channels.

    Note:
        ``convert_alpha()`` requires a display mode to have been set, so call
        this after ``pygame.display.set_mode``.
    """
    height = size[1]
    if height <= 0:
        raise ValueError(f"gradient height must be positive, got {height}")

    start = _gradient_rgba(startcolor)
    end = _gradient_rgba(endcolor)

    column = pygame.Surface((1, height)).convert_alpha()
    step = 1.0 / height
    # per-row increment of each channel
    deltas = [(e - s) * step for s, e in zip(start, end)]

    for y in range(height):
        column.set_at((0, y), tuple(int(s + d * y) for s, d in zip(start, deltas)))
    return pygame.transform.scale(column, size)

def draw_line_with_alpha(screen, color, pos1, pos2, size, alpha):
    """Draw a ``size``-px-wide line from ``pos1`` to ``pos2`` with surface alpha ``alpha``.

    The line is rendered onto a temporary per-pixel-alpha surface just large
    enough to contain it (padded by ``size`` on every side). The surface
    alpha is then set, and the surface is blitted so the line lands on
    ``pos1``/``pos2``.
    """
    dx = pos2[0] - pos1[0]
    dy = pos2[1] - pos1[1]
    span_x, span_y = abs(dx), abs(dy)

    scratch = pygame.Surface((span_x + 2 * size, span_y + 2 * size), pygame.SRCALPHA)

    # Inside the padded box the line runs top-left -> bottom-right when dx
    # and dy share a sign. Otherwise it runs top-right -> bottom-left, which
    # also covers purely horizontal and purely vertical segments.
    same_direction = dx * dy > 0
    if same_direction:
        a, b = (size, size), (size + span_x, size + span_y)
    else:
        a, b = (size + span_x, size), (size, size + span_y)
    pygame.draw.line(scratch, color, a, b, size)

    scratch.set_alpha(alpha)

    top_left = (min(pos1[0], pos2[0]) - size, min(pos1[1], pos2[1]) - size)
    screen.blit(scratch, top_left)

def draw_circle_with_alpha(screen, color, pos, radius, alpha):
    """Draw a filled circle of ``radius`` centred on ``pos`` with surface alpha ``alpha``.

    The circle is rendered onto a temporary per-pixel-alpha surface
    ``2 * radius`` pixels square. The surface alpha is applied, and the
    surface is blitted so its centre lands on ``pos``.

    This is called once per fading particle per frame, so it deliberately
    does no logging.
    """
    circle_surface = pygame.Surface((2 * radius, 2 * radius), pygame.SRCALPHA)
    pygame.draw.circle(circle_surface, color, (radius, radius), radius)
    circle_surface.set_alpha(alpha)

    # offset by the radius so the circle's centre sits on pos
    screen.blit(circle_surface, (pos[0] - radius, pos[1] - radius))

def random_unit():
    """Return a unit-length ``Vector2`` pointing in a uniformly random direction.

    The rotation angle (in degrees) is drawn from the continuous range
    [0, 360]. Drawing whole degrees only would quantise directions to 360
    spokes. In bursts with thousands of particles that quantisation shows up
    as radial streaks.
    """
    return Vector2(1, 0).rotate(random.uniform(0, 360))

def color_ping_pong(color1, color2, t, length):
    """Oscillate between ``color1`` and ``color2`` with period ``length``.

    Within each period the color blends from ``color1`` to ``color2`` over the
    first half, then back to ``color1`` over the second half. Each half is
    normalised by its own duration (``length / 2``), so the result is
    continuous in ``t`` and actually reaches ``color2`` mid-period.

    Args:
        color1, color2: colors passed to ``color_lerp``.
        t: time value; wrapped into a single period with ``%``.
        length: period of one full color1 -> color2 -> color1 cycle (non-zero).
    """
    half = length / 2
    phase = t % length
    if phase < half:
        return color_lerp(color1, color2, phase / half)
    return color_lerp(color2, color1, (phase - half) / half)

In [None]:
class ScreenProfile:
    """Maps between world coordinates and screen pixel coordinates.

    World space is the rectangle [minX, maxX] x [minY, maxY] with y pointing
    up. Screen space is ``screen_size`` pixels with y pointing down, so both
    conversions flip the y axis.
    """

    def __init__(self, minX, minY, maxX, maxY, screen_size):
        self.minX = minX
        self.minY = minY
        self.maxX = maxX
        self.maxY = maxY
        # centre of the visible world rectangle
        self.midX = (maxX + minX) / 2
        self.midY = (maxY + minY) / 2
        self.screen_size = screen_size

    def to_screen(self, pos):
        """Convert world point ``pos`` (with ``.x``/``.y``) to pixel coordinates.

        Returns a ``Vector2`` whose components are truncated to whole pixels.
        """
        x = (pos.x - self.minX) / (self.maxX - self.minX) * self.screen_size[0]

        # flip y axis: world minY maps to the bottom edge of the screen
        y = (pos.y - self.minY) / (self.maxY - self.minY) * self.screen_size[1]
        y = self.screen_size[1] - y
        return Vector2(int(x), int(y))

    def to_world(self, pos):
        """Convert pixel coordinates ``pos`` (with ``.x``/``.y``) to world coordinates.

        This is the inverse of ``to_screen``, up to its pixel truncation.
        """
        x = (pos.x / self.screen_size[0]) * (self.maxX - self.minX) + self.minX

        # flip y axis: pixel row 0 is the top of the world (maxY), and the
        # offset is measured down from maxY alone, so the mapping inverts
        # to_screen for any minY (not only minY == 0).
        y = self.maxY - (pos.y / self.screen_size[1]) * (self.maxY - self.minY)
        return Vector2(x, y)

    def point_on_screen(self, pos):
        """Return True if world point ``pos`` lies inside the visible rectangle (inclusive)."""
        return self.minX <= pos.x <= self.maxX and self.minY <= pos.y <= self.maxY

class BasicParticle:
    """A point particle that moves under gravity and air drag until its lifetime runs out.

    Attributes:
        position, velocity: world-space vectors, advanced every update.
        size: radius passed to ``pygame.draw.circle``.
        color: fill color passed to ``pygame.draw.circle``.
        lifetime: remaining time; the particle dies once it reaches <= 0.
        gravity: constant acceleration added to ``velocity`` every update.
        air_resistance: fraction of velocity removed per unit of ``delta_time``.
    """

    def __init__(self, position, size, color, lifetime, velocity):
        self.position = position
        self.velocity = velocity
        self.size = size
        self.color = color
        self.lifetime = lifetime
        self.alive = True
        self.gravity = Vector2(0, -9.8)
        self.air_resistance = 0.5

    def update(self, delta_time):
        """Advance the particle by ``delta_time`` (no-op once dead)."""
        if not self.alive:
            return

        # accelerate, apply linear drag, then move
        self.velocity += self.gravity * delta_time
        self.velocity *= 1 - (self.air_resistance * delta_time)
        self.position += self.velocity * delta_time

        self.lifetime -= delta_time
        if self.lifetime <= 0:
            self.alive = False

    def draw(self, screen, screen_profile):
        """Draw the particle as a filled circle at its screen position (only while alive)."""
        if not self.alive:
            return
        center = screen_profile.to_screen(self.position)
        pygame.draw.circle(screen, self.color, center, self.size)


class TrailParticle(BasicParticle):
    """A particle drawn as a fading line trail instead of a dot.

    ``trail`` stores recent positions, newest first, capped at
    ``trail_length`` points. After ``lifetime`` runs out the particle stays
    alive while its trail retracts by one point per update. It dies once the
    trail is empty.
    """

    def __init__(self, position, size, color, lifetime, velocity, trail_length):
        super().__init__(position, size, color, lifetime, velocity)

        self.gravity = Vector2(0, -1 * 9.8)
        self.trail = [Vector2(self.position)]
        self.trail_length = trail_length

    def update(self, delta_time):
        """Advance the particle and its trail by ``delta_time``."""
        if self.alive:
            # update velocity (gravity, then linear air drag)
            self.velocity += self.gravity * delta_time
            self.velocity *= 1 - (self.air_resistance * delta_time)

            # update position
            self.position += self.velocity * delta_time

            # update lifetime and record the new head of the trail
            self.lifetime -= delta_time
            self.trail.insert(0, Vector2(self.position))

            if self.lifetime <= 0:
                # Retract: one point was just inserted, so removing two shrinks
                # the trail by one point per update. The trail held at least
                # one point before the insert, so both pops are safe.
                self.trail.pop()
                self.trail.pop()
                if not self.trail:
                    self.alive = False
            elif len(self.trail) > self.trail_length:
                self.trail.pop()

    def draw(self, screen, screen_profile):
        """Draw the trail as at most 4 line segments fading from head to tail.

        Nothing is drawn unless the trail has at least two points and at least
        one of its ends is inside the visible world rectangle.
        """
        if not self.alive or len(self.trail) < 2:
            return

        head_visible = screen_profile.point_on_screen(self.trail[0])
        tail_visible = screen_profile.point_on_screen(self.trail[-1])
        if not (head_visible or tail_visible):
            return

        n = len(self.trail)
        step = n // 4 + 1
        # Consecutive segments share endpoints and the last one is clipped to
        # the tail. No stretch of the trail is drawn twice, which would
        # double its opacity there.
        for i in range(0, n - 1, step):
            j = min(i + step, n - 1)
            alpha = lerp(255, 40, j / n)
            pos1 = screen_profile.to_screen(self.trail[i])
            pos2 = screen_profile.to_screen(self.trail[j])
            draw_line_with_alpha(screen, self.color, pos1, pos2, int(self.size), alpha)


class WeirdParticle(BasicParticle):
    """A particle that is flung in a random direction part-way through its life.

    At the switch point its velocity is replaced by a random unit vector
    scaled to ``SPEED_INCREASE`` times its current speed.
    """

    # speed multiplier applied at the switch
    SPEED_INCREASE = 3

    def __init__(self, position, size, color, lifetime, velocity, switch_time):
        super().__init__(position, size, color, lifetime, velocity)
        # switch_time is the fraction (0..1) of the lifetime that must elapse
        # before the switch. It is stored as the remaining lifetime at which
        # to switch.
        self.switch_time = self.lifetime * (1 - switch_time)
        self.switched = False

    def update(self, delta_time):
        """Advance by ``delta_time``; the switch and the expiry are checked independently."""
        if not self.alive:
            return

        self.velocity += self.gravity * delta_time
        self.velocity *= 1 - (self.air_resistance * delta_time)
        self.position += self.velocity * delta_time

        self.lifetime -= delta_time
        # a flag (not a sentinel threshold) guarantees the switch fires exactly once
        if not self.switched and self.lifetime <= self.switch_time:
            self.velocity = WeirdParticle.SPEED_INCREASE * random_unit() * self.velocity.length()
            self.switched = True
        # separate check so a switch on the final frame cannot postpone expiry
        if self.lifetime <= 0:
            self.alive = False

class TwinklingParticle(BasicParticle):
    """A basic particle whose color oscillates between two colors while it lives.

    The color is recomputed on every draw from the remaining ``lifetime`` via
    ``color_ping_pong``, using ``twinkle_speed`` as the oscillation period.
    """

    def __init__(self, position, size, color1, color2, lifetime, velocity, twinkle_speed):
        # color1 is stored as self.color by the base class
        super().__init__(position, size, color1, lifetime, velocity)
        self.color2 = color2
        self.twinkle_speed = twinkle_speed

    def draw(self, screen, screen_profile):
        """Draw a filled circle in the current twinkle color (only while alive)."""
        if not self.alive:
            return
        center = screen_profile.to_screen(self.position)
        shade = color_ping_pong(self.color, self.color2, self.lifetime, self.twinkle_speed)
        pygame.draw.circle(screen, shade, center, self.size)


class FadingParticle(BasicParticle):
    """A nearly weightless particle that fades out over the end of its life.

    Args:
        fade_start_time: fraction (0..1) of the lifetime that elapses before
            fading begins. 0 fades over the whole lifetime; 1 never fades.
    """

    def __init__(self, position, size, color, lifetime, fade_start_time, velocity):
        super().__init__(position, size, color, lifetime, velocity)

        # Convert the fraction into the *remaining* lifetime at which fading
        # starts. The fade then runs over exactly that remaining time, so
        # alpha starts at 255 and reaches 0 as the lifetime runs out.
        self.fade_start_time = self.lifetime * (1 - fade_start_time)
        self.fade_duration = self.fade_start_time

        # only 1% of normal gravity, so the particle lingers where it spawned
        self.gravity = Vector2(0, 0.01 * -9.8)

    def draw(self, screen, screen_profile):
        """Draw the particle, fully opaque until the fade starts and translucent after."""
        if not self.alive:
            return

        if self.fade_duration > 0 and self.lifetime <= self.fade_start_time:
            pos = screen_profile.to_screen(self.position)
            alpha = lerp(0, 255, self.lifetime / self.fade_duration)
            draw_circle_with_alpha(screen, self.color, pos, self.size, alpha)
        else:
            super().draw(screen, screen_profile)


class SparklingTrailParticle(BasicParticle):
    """An invisible, heavy particle that sheds ``FadingParticle`` sparks as it flies.

    While alive, ``update`` returns a (possibly empty) list of newly spawned
    sparks; once dead it returns ``None``. The owning firework must add the
    sparks to its own components.
    """

    def __init__(self, position, size, color, lifetime, velocity, trail_density):
        super().__init__(position, size, color, lifetime, velocity)

        self.gravity = Vector2(0, -2 * 9.8)
        self.trail_density = trail_density  # sparks per unit of delta_time
        # Fractional sparks are carried over between updates. Without this,
        # densities below one spark per frame would never spawn anything.
        self._spawn_budget = 0.0

    def update(self, delta_time):
        """Advance the particle and return the sparks emitted this update (``None`` once dead)."""
        super().update(delta_time)
        if not self.alive:
            return None

        self._spawn_budget += self.trail_density * delta_time
        num_new_particles = int(self._spawn_budget)
        self._spawn_budget -= num_new_particles

        fade_start_time = 0
        # Each spark gets its own copy of the position. Sharing self.position
        # would alias one mutable Vector2 between this particle and all its
        # sparks, so any of them moving would move the others.
        return [
            FadingParticle(Vector2(self.position), self.size, self.color, self.lifetime, fade_start_time, (0, 0))
            for _ in range(num_new_particles)
        ]

    def draw(self, screen, screen_profile):
        """Draw nothing: only the sparks left behind are visible."""

class ExplodingParticle(BasicParticle):
    """A trailing particle that bursts into a new firework when its lifetime ends.

    ``update`` returns the spawned firework on the frame the particle expires
    and ``None`` otherwise; the caller decides where to put it.
    """

    def __init__(self, position, size, color, lifetime, velocity, trail_length, explosion_type, explosion_intensity):
        super().__init__(position, size, color, lifetime, velocity)

        # class whose random(position, t) builds the burst, and the t passed to it
        self.explosion_type = explosion_type
        self.explosion_intensity = explosion_intensity

        self.gravity = Vector2(0, -1 * 9.8)
        self.trail = [Vector2(self.position)]
        self.trail_length = trail_length

    def create_firework(self):
        """Build the firework this particle turns into when it expires."""
        return self.explosion_type.random(self.position, self.explosion_intensity)

    def update(self, delta_time):
        """Advance the particle; return its firework on the frame it expires, else ``None``."""
        if not self.alive:
            return None

        self.velocity += self.gravity * delta_time
        self.velocity *= 1 - (self.air_resistance * delta_time)
        self.position += self.velocity * delta_time

        self.lifetime -= delta_time
        self.trail.insert(0, Vector2(self.position))

        if self.lifetime <= len(self.trail) * delta_time:
            # Near expiry: drop up to two tail points per update (a net loss of
            # one, after the insert above) so the trail retracts before the
            # burst. Slice deletion tolerates a trail with fewer than two points.
            del self.trail[-2:]
        elif len(self.trail) > self.trail_length:
            self.trail.pop()

        if self.lifetime <= 0:
            self.alive = False
            return self.create_firework()
        return None

    def draw(self, screen, screen_profile):
        """Draw the head and a trail of at most 8 segments fading from head to tail.

        Nothing is drawn unless the trail has at least two points and at least
        one of its ends is inside the visible world rectangle.
        """
        if not self.alive or len(self.trail) < 2:
            return

        head_visible = screen_profile.point_on_screen(self.trail[0])
        tail_visible = screen_profile.point_on_screen(self.trail[-1])
        if not (head_visible or tail_visible):
            return

        # draw particle
        pos = screen_profile.to_screen(self.position)
        pygame.draw.circle(screen, self.color, pos, self.size)

        n = len(self.trail)
        step = n // 8 + 1
        # Consecutive segments share endpoints and the last one is clipped to
        # the tail. No stretch of the trail is drawn twice, which would
        # double its opacity there.
        for i in range(0, n - 1, step):
            j = min(i + step, n - 1)
            alpha = lerp(255, 40, j / n)
            pos1 = screen_profile.to_screen(self.trail[i])
            pos2 = screen_profile.to_screen(self.trail[j])
            draw_line_with_alpha(screen, self.color, pos1, pos2, int(self.size), alpha)


class LaunchParticle(ExplodingParticle):
    """The rising shell that carries a pre-built firework to its burst point.

    Its initial velocity is solved so that, under its own gravity and with no
    drag, it reaches ``target_position`` after exactly ``lifetime``. It then
    releases ``firework`` from ``create_firework``.
    """

    def __init__(self, position, size, color, lifetime, trail_length, firework, target_position):
        # velocity is computed below; the explosion settings are not used
        super().__init__(position, size, color, lifetime, None, trail_length, None, None)
        self.firework = firework
        # No drag, so the closed-form projectile equation below describes the
        # flight (the per-frame integration in update approximates it).
        self.air_resistance = 0
        self.gravity = Vector2(0, -3 * 9.8)

        # Solve p_target = p0 + v0 * T + 1/2 * g * T^2 for v0.
        displacement = target_position - self.position
        drop = 1/2 * self.gravity * lifetime ** 2
        self.velocity = (displacement - drop) / lifetime

    def create_firework(self):
        """Return the firework this shell was launched with."""
        return self.firework


class Firework:
    """A group of components updated and drawn together.

    Each component must provide ``update(delta_time)``,
    ``draw(screen, screen_profile)`` and an ``alive`` attribute. The firework
    is alive while at least one of its components is.
    """

    def __init__(self, components):
        self.components = components
        self.alive = True

    def is_alive(self):
        """Return True if any component is still alive."""
        for part in self.components:
            if part.alive:
                return True
        return False

    def update(self, delta_time):
        """Advance every component by ``delta_time`` and refresh ``alive``.

        Values returned by the components' ``update`` calls are ignored.
        """
        for part in self.components:
            part.update(delta_time)

        self.alive = self.is_alive()

    def draw(self, screen, screen_profile):
        """Draw every component in order."""
        for part in self.components:
            part.draw(screen, screen_profile)

class BasicFirework(Firework):
    """A burst of plain ``BasicParticle``s flying out from one point."""

    def __init__(self, position, num_particles, max_particle_size, colors, lifetime, speed):
        def spawn():
            # every particle gets its own copy of the burst centre
            start = Vector2(position)
            velocity = random.uniform(0, speed) * random_unit()
            particle_lifetime = random.uniform(lifetime/3, lifetime)
            radius = random.uniform(1, max_particle_size)
            return BasicParticle(start, radius, random.choice(colors), particle_lifetime, velocity)

        super().__init__([spawn() for _ in range(num_particles)])

    def random(position, t):
        """Build a BasicFirework whose size scales with intensity ``t`` (clamped to [0, 1]).

        Called on the class, e.g. ``BasicFirework.random(pos, t)``; it takes no ``self``.
        """
        particle_count = lerp(50, 2500, t, type=int)
        biggest = lerp(2, 6, t)
        duration = lerp(2, 15, t)
        top_speed = lerp(5, 250, t, type=int)

        # 1-3 light colors, each channel in [150, 255]
        palette = []
        for _ in range(np.random.randint(1, 4)):
            palette.append((random.uniform(150, 255), random.uniform(150, 255), random.uniform(150, 255)))

        return BasicFirework(position, particle_count, biggest, palette, duration, top_speed)


class TrailFirework(Firework):
    """A burst of ``TrailParticle``s, each streaking out with its own fading trail."""

    def __init__(self, position, num_particles, max_particle_size, colors, lifetime, speed, trail_length):
        def spawn():
            # every particle gets its own copy of the burst centre
            origin = Vector2(position)
            velocity = random.uniform(0, speed) * random_unit()
            ttl = random.uniform(lifetime/3, lifetime)
            radius = random.uniform(1, max_particle_size)
            tail = random.uniform(trail_length/3, trail_length)
            return TrailParticle(origin, radius, random.choice(colors), ttl, velocity, tail)

        super().__init__([spawn() for _ in range(num_particles)])

    def random(position, t):
        """Build a TrailFirework scaled by intensity ``t`` (clamped to [0, 1]).

        Called on the class; it takes no ``self``.
        """
        num_particles = lerp(50, 1000, t, type=int)
        max_size = lerp(2, 6, t)
        lifetime = lerp(2, 15, t)
        speed = lerp(25, 200, t, type=int)
        # fixed; each particle picks its own length in [trail_length/3, trail_length]
        trail_length = 40

        # 1-2 light colors, each channel in [150, 255]
        colors = [
            (random.uniform(150, 255), random.uniform(150, 255), random.uniform(150, 255))
            for _ in range(np.random.randint(1, 3))
        ]

        print(f'Generating firework with {num_particles} particles, lifetime {lifetime}, speed {speed}')
        return TrailFirework(position, num_particles, max_size, colors, lifetime, speed, trail_length)

class SparklingTrailFirework(Firework):
    """A handful of heavy ``SparklingTrailParticle``s that shed fading sparks as they fly.

    Sparks emitted by the particles are adopted into ``components``. Dead
    components are pruned on every update, so the list does not keep growing
    for the firework's whole lifetime.
    """

    def __init__(self, position, num_particles, max_particle_size, colors, lifetime, speed, trail_density):
        particles = []

        for _ in range(num_particles):
            # copy position so particles do not share one Vector2
            particle_position = Vector2(position)

            # random velocity, lifetime and size
            particle_speed = random.uniform(0, speed)
            velocity = particle_speed * random_unit()
            particle_lifetime = random.uniform(lifetime/3, lifetime)
            particle_size = random.uniform(1, max_particle_size)

            color = random.choice(colors)

            particle = SparklingTrailParticle(particle_position, particle_size, color, particle_lifetime, velocity, trail_density)
            particles.append(particle)

        super().__init__(particles)

    def update(self, delta_time):
        """Advance all components, adopt newly emitted sparks and drop dead components."""
        spawned = []
        for component in self.components:
            result = component.update(delta_time)
            # SparklingTrailParticle.update returns a (possibly empty) list of sparks
            if result is not None:
                spawned.extend(result)

        # New sparks are added after the loop, so their first update happens next frame.
        self.components = [c for c in self.components if c.alive] + spawned
        self.alive = bool(self.components)

    def random(position, t):
        """Build a SparklingTrailFirework scaled by intensity ``t`` (clamped to [0, 1]).

        Called on the class; it takes no ``self``.
        """
        num_particles = lerp(5, 10, t, type=int)
        max_size = lerp(5, 10, t)
        lifetime = lerp(6, 15, t)
        speed = lerp(20, 250, t, type=int)
        trail_density = 100

        # a single light color, each channel in [150, 255]
        num_colors = 1
        colors = [(random.uniform(150, 255), random.uniform(150, 255), random.uniform(150, 255)) for _ in range(num_colors)]

        print(f'Generating firework with {num_particles} particles, lifetime {lifetime}, speed {speed}')
        firework = SparklingTrailFirework(position, num_particles, max_size, colors, lifetime, speed, trail_density)
        return firework


class WeirdFirework(Firework):
    """A burst of ``WeirdParticle``s that change direction part-way through their flight."""

    def __init__(self, position, num_particles, max_particle_size, colors, lifetime, speed, switch_time):
        def spawn():
            # every particle gets its own copy of the burst centre
            origin = Vector2(position)
            velocity = random.uniform(0, speed) * random_unit()
            # lifetimes are kept in the upper third so the switches line up
            ttl = random.uniform(2 * lifetime/3, lifetime)
            radius = random.uniform(1, max_particle_size)
            return WeirdParticle(origin, radius, random.choice(colors), ttl, velocity, switch_time)

        super().__init__([spawn() for _ in range(num_particles)])

    def random(position, t):
        """Build a WeirdFirework scaled by intensity ``t`` (clamped to [0, 1]).

        Called on the class; it takes no ``self``. Larger ``t`` switches earlier.
        """
        num_particles = lerp(50, 250, t, type=int)
        max_particle_size = lerp(3, 8, t)
        lifetime = lerp(6, 12, t)
        speed = lerp(15, 60, t, type=int)
        switch_time = lerp(0.25, 0.125, t)

        # 1-2 light colors, each channel in [180, 255]
        colors = [
            (random.uniform(180, 255), random.uniform(180, 255), random.uniform(180, 255))
            for _ in range(np.random.randint(1, 3))
        ]

        print(f'Generating firework with {num_particles} particles, lifetime {lifetime}, speed {speed}')
        return WeirdFirework(position, num_particles, max_particle_size, colors, lifetime, speed, switch_time)


class TwinklingFirework(Firework):
    """A burst of ``TwinklingParticle``s that all flicker between the same two colors."""

    def __init__(self, position, num_particles, max_particle_size, color1, color2, lifetime, speed, twinkle_speed):
        def spawn():
            # every particle gets its own copy of the burst centre
            origin = Vector2(position)
            velocity = random.uniform(0, speed) * random_unit()
            ttl = random.uniform(lifetime/3, lifetime)
            radius = random.uniform(1, max_particle_size)
            return TwinklingParticle(origin, radius, color1, color2, ttl, velocity, twinkle_speed)

        super().__init__([spawn() for _ in range(num_particles)])

    def random(position, t):
        """Build a TwinklingFirework scaled by intensity ``t`` (clamped to [0, 1]).

        Called on the class; it takes no ``self``.
        """
        num_particles = lerp(50, 2500, t, type=int)
        max_particle_size = lerp(2, 6, t)
        lifetime = lerp(2, 15, t)
        speed = lerp(5, 250, t, type=int)

        def light_color():
            # each channel in [150, 255]
            return (random.uniform(150, 255), random.uniform(150, 255), random.uniform(150, 255))

        color1 = light_color()
        color2 = light_color()
        twinkling_speed = random.uniform(0.2, 0.5)  # period of one color cycle

        print(f'Generating firework with {num_particles} particles, lifetime {lifetime}, speed {speed}')
        return TwinklingFirework(position, num_particles, max_particle_size, color1, color2, lifetime, speed, twinkling_speed)


class RecursiveFirework(Firework):
    """A few large trailing shells that each burst into a secondary firework.

    Each component starts as an ``ExplodingParticle``. When one expires, its
    slot in ``components`` is replaced by the firework it spawned.
    """

    def __init__(self, position, num_particles, max_particle_size, colors, lifetime, speed, trail_length, explosion_type, explosion_intensity):
        particles = []

        for _ in range(num_particles):
            # copy position so particles do not share one Vector2
            particle_position = Vector2(position)

            # random velocity, lifetime and size (shells are at least 3 px)
            particle_speed = random.uniform(0, speed)
            velocity = particle_speed * random_unit()
            particle_lifetime = random.uniform(lifetime/3, lifetime)
            particle_size = random.uniform(3, max_particle_size)

            color = random.choice(colors)

            # each shell's burst gets 50-100% of the requested intensity
            intensity = random.uniform(0.5, 1) * explosion_intensity

            particle = ExplodingParticle(particle_position, particle_size, color, particle_lifetime, velocity, trail_length, explosion_type, intensity)
            particles.append(particle)

        super().__init__(particles)

    def update(self, delta_time):
        """Advance all components, swapping each expired shell for the firework it spawned."""
        for i, component in enumerate(self.components):
            result = component.update(delta_time)
            # an ExplodingParticle returns its burst on the frame it expires
            if result is not None:
                self.components[i] = result

        self.alive = any(component.alive for component in self.components)

    def random(position, t):
        """Build a RecursiveFirework scaled by intensity ``t``.

        Called on the class; it takes no ``self``. ``t`` is clamped to [0, 1]
        before it scales ``explosion_intensity``, matching the clamping
        ``lerp`` applies to every other parameter.
        """
        t = max(0, min(t, 1))

        num_particles = lerp(3, 15, t, type=int)
        max_particle_size = lerp(4, 12, t)
        lifetime = random.uniform(1, 3)
        speed = lerp(40, 100, t, type=int)
        trail_length = 75

        # 1-3 light colors, each channel in [150, 255]
        num_colors = np.random.randint(1, 4)
        colors = [(random.uniform(150, 255), random.uniform(150, 255), random.uniform(150, 255)) for _ in range(num_colors)]

        # all shells of one firework burst into the same (randomly chosen) type
        explosion_type = choice(EXPLOSION_TYPES, p=RECURSIVE_EXPLOSION_WEIGHTS)
        explosion_intensity = random.uniform(0.1, 0.3) * t

        print(f'Generating firework with {num_particles} particles, lifetime {lifetime}, speed {speed}')
        firework = RecursiveFirework(position, num_particles, max_particle_size, colors, lifetime, speed, trail_length, explosion_type, explosion_intensity)
        return firework

# Firework classes that can be launched (TwinklingFirework is defined but not
# part of the rotation).
EXPLOSION_TYPES = [
    BasicFirework,
    TrailFirework,
    SparklingTrailFirework,
    WeirdFirework,
    RecursiveFirework,
]
# Probabilities for a clicked firework, index-aligned with EXPLOSION_TYPES.
# They are passed as numpy choice's ``p=``, so each list must sum to 1.
EXPLOSION_WEIGHTS = [
    0.3,  # BasicFirework
    0.3,  # TrailFirework
    0,    # SparklingTrailFirework (disabled)
    0.2,  # WeirdFirework
    0.2,  # RecursiveFirework
]
# Probabilities for the secondary bursts of a RecursiveFirework.
RECURSIVE_EXPLOSION_WEIGHTS = [
    0.4,  # BasicFirework
    0.2,  # TrailFirework
    0,    # SparklingTrailFirework (disabled)
    0.3,  # WeirdFirework
    0.1,  # RecursiveFirework
]

In [None]:
def _launch_towards(screen_profile, screen_pos, t):
    """Build a launch shell rising from the bottom-centre of the world to pixel ``screen_pos``."""
    target = screen_profile.to_world(Vector2(screen_pos))

    firework_type = choice(EXPLOSION_TYPES, p=EXPLOSION_WEIGHTS)
    firework = firework_type.random(target, t)

    launch_duration = lerp(1.5, 0.6, t)  # stronger fireworks rise faster
    launch_size = lerp(3, 8, t)
    launch_trail_length = 20
    return LaunchParticle((screen_profile.midX, screen_profile.minY), launch_size, (225, 225, 225), launch_duration, launch_trail_length, firework, target)


def run_simulation(screen_profile):
    """Run the interactive firework loop until Esc is pressed or the window is closed.

    Each frame advances every active firework (and launch shell) by a fixed
    ``1/FPS`` step and drops the dead ones. It then draws the background
    gradient and the fireworks, and finally handles input. Releasing a mouse
    button launches a shell toward the click. How long the button was held,
    capped at 1.5 s, sets the intensity ``t`` in [0, 1].

    A window-close request is re-posted before returning, so the calling
    menu loop sees it and exits as well.

    Args:
        screen_profile: ``ScreenProfile`` used for world <-> screen conversion.
    """
    max_hold_ms = 1500  # holding longer than this adds no intensity
    fireworks = []
    mouse_down_time = 0

    background_gradient = vertical_gradient(screen_profile.screen_size, (0, 0, 25), (35, 35, 60))

    while True:
        # A launch shell returns its firework on the frame it bursts. Collect
        # those separately instead of appending to the list being iterated.
        spawned = []
        for firework in fireworks:
            result = firework.update(1/FPS)
            if result is not None:
                spawned.append(result)
        fireworks = [fw for fw in fireworks if fw.alive] + spawned

        screen.blit(background_gradient, (0, 0))
        for firework in fireworks:
            firework.draw(screen, screen_profile)

        pygame.display.flip()
        clock.tick(FPS)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # hand the close request back to the menu loop
                pygame.event.post(pygame.event.Event(pygame.QUIT))
                return
            if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
                return

            if event.type == pygame.MOUSEBUTTONDOWN:
                mouse_down_time = pygame.time.get_ticks()
            elif event.type == pygame.MOUSEBUTTONUP:
                held_ms = pygame.time.get_ticks() - mouse_down_time
                # Clamp so the 1.5 s cap holds for every firework type, not
                # only for the parameters that go through lerp's clamp.
                t = min(held_ms / max_hold_ms, 1.0)
                # event.pos is where the button was released
                fireworks.append(_launch_towards(screen_profile, event.pos, t))

In [None]:
def main():
    """Open the window, show the start menu and run until the user quits.

    The menu's Start button runs ``run_simulation``; pressing Esc there
    returns to the menu. ``pygame.quit()`` is always called on the way out,
    so the window closes even when this runs inside a notebook cell.
    """
    global screen, clock, menu

    # the world is 240 units wide, with the window's aspect ratio
    width = 240
    height = width * SCREEN_HEIGHT / SCREEN_WIDTH
    screen_profile = ScreenProfile(-width // 2, 0, width // 2, height, SCREEN_SIZE)

    pygame.init()
    pygame.font.init()
    try:
        pygame.display.set_caption('Fireworks')
        screen = pygame.display.set_mode(SCREEN_SIZE)
        clock = pygame.time.Clock()

        menu = pygame_menu.Menu('Fireworks', SCREEN_WIDTH, SCREEN_HEIGHT, theme=pygame_menu.themes.THEME_DARK)
        menu.add.button('Start', run_simulation, screen_profile, font_size=30)
        menu.add.button('Quit', pygame_menu.events.EXIT, font_size=30)

        while True:
            events = pygame.event.get()
            for event in events:
                if event.type == pygame.QUIT:
                    return
                elif event.type == pygame_menu.events.BACK:
                    return

            menu.update(events)
            menu.draw(screen)

            pygame.display.update()
            clock.tick(FPS)
    finally:
        # safe to call more than once (the EXIT menu action may already have quit)
        pygame.quit()

In [None]:
# Start the game only when this file/notebook is executed directly, not when imported.
if __name__ == '__main__':
    main()