In [1]:
# === SETUP ===
import os
from html import escape
from math import ceil

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
import requests
from IPython.display import HTML, clear_output, display
from joblib import load
from PIL import Image

## NOTE

ipywidgets has a known bug that prevents the widgets below from displaying in a Jupyter Notebook environment.

A working implementation is available in Google Colaboratory [here](https://colab.research.google.com/drive/1VOhNRQxL37sbY_RtUvT60ia6Px6RAwit?usp=drive_link).

In [3]:
# === LOAD MODEL BUNDLE ===

# The saved bundle is indexed like a dict: "model" holds the estimator used for
# predict_proba below, and "label_columns" the label names its outputs are paired with.
# joblib.load unpickles the file, so only load bundles from a trusted source.
model_path = "models/vgc_regi_restrictedcore_model.joblib"
bundle = load(model_path)
model, label_columns = bundle["model"], bundle["label_columns"]

In [4]:
# Training feature frame; the cells below only read its column names
# (the "core_<Pokemon>" indicator columns).
df_path = "dataframes"
X_df = pd.read_csv(df_path + "/X_df.csv")

In [5]:
# === TEAMMATE PREDICTION UI ===

# Feature columns are expected to be named "core_<Pokemon>"; stripping the prefix gives
# the (sorted) names offered in both dropdowns.
restricted_pool = sorted(col.replace("core_", "") for col in X_df.columns)


def _make_core_dropdown(label):
    """Return a 300px-wide Dropdown over ``restricted_pool`` with the given description."""
    return widgets.Dropdown(
        options=restricted_pool,
        description=label,
        layout=widgets.Layout(width='300px'),
    )


core1_dropdown = _make_core_dropdown('Core 1:')
core2_dropdown = _make_core_dropdown('Core 2:')
predict_button = widgets.Button(description='Predict Teammates')

display(core1_dropdown, core2_dropdown, predict_button)

Dropdown(description='Core 1:', layout=Layout(width='300px'), options=('Calyrex', 'Calyrex-Ice', 'Calyrex-Shad…

Dropdown(description='Core 2:', layout=Layout(width='300px'), options=('Calyrex', 'Calyrex-Ice', 'Calyrex-Shad…

Button(description='Predict Teammates', style=ButtonStyle())

In [6]:
# Show every selectable restricted Pokémon (sorted feature names with "core_" removed).
restricted_pool

['Calyrex',
 'Calyrex-Ice',
 'Calyrex-Shadow',
 'Cosmoem',
 'Dialga',
 'Dialga-Origin',
 'Eternatus',
 'Giratina',
 'Giratina-Origin',
 'Groudon',
 'Ho-Oh',
 'Koraidon',
 'Kyogre',
 'Kyurem',
 'Kyurem-Black',
 'Kyurem-White',
 'Lugia',
 'Lunala',
 'Mewtwo',
 'Miraidon',
 'Necrozma',
 'Necrozma-Dawn-Wings',
 'Necrozma-Dusk-Mane',
 'Palkia',
 'Palkia-Origin',
 'Rayquaza',
 'Reshiram',
 'Solgaleo',
 'Terapagos',
 'Zacian',
 'Zacian-Crowned',
 'Zamazenta',
 'Zamazenta-Crowned',
 'Zekrom']

In [7]:
# === PREDICT TEAMMATES ===

def predict_teammates(core, model, X_df, label_columns, top_n=20):
    """Rank likely teammates for a restricted core.

    Builds a single-row feature vector with a 1 in ``core_<name>`` for each Pokémon
    in ``core`` (0 everywhere else), runs ``model.predict_proba`` on it, and returns
    the ``top_n`` labels sorted by predicted probability (highest first).

    Parameters
    ----------
    core : iterable of str
        Pokémon names without the ``core_`` prefix, e.g. the two dropdown values.
        Names with no matching feature column are reported and ignored.
    model : fitted estimator
        ``predict_proba`` must return one ``(n_samples, n_classes)`` array per label,
        and ``model.estimators_[i].classes_`` must exist for labels whose array has a
        single column (presumably a multi-output wrapper such as sklearn's
        MultiOutputClassifier -- TODO confirm against the training code).
    X_df : pandas.DataFrame
        Only its columns are used, to define feature names and their order.
    label_columns : sequence of str
        ``teammate_<name>`` label names, in the same order as the model's outputs.
    top_n : int or None, default 20
        Number of rows to return; ``None`` returns every label.

    Returns
    -------
    pandas.DataFrame
        Columns ``Teammate`` and ``Predicted Probability``.

    Raises
    ------
    ValueError
        If the model returns a different number of outputs than ``label_columns``
        (previously this was silently truncated by ``zip``).
    """
    # Integer zeros with the training column order; filling an empty frame instead
    # would leave every column with object dtype.
    input_row = pd.DataFrame(0, index=[0], columns=X_df.columns)

    for mon in core:
        col = f"core_{mon}"
        if col in input_row.columns:
            input_row.at[0, col] = 1
        else:
            print(f"Warning: {col} not in input features")

    per_label = model.predict_proba(input_row)
    if len(per_label) != len(label_columns):
        raise ValueError(
            f"model returned {len(per_label)} outputs but {len(label_columns)} label columns were given"
        )

    probs = []
    for i, prob_arr in enumerate(per_label):
        if prob_arr.shape[1] == 2:
            # Binary label: column 1 is the positive class (assumes classes_ == [0, 1]).
            probs.append(prob_arr[0, 1])
        else:
            # The label saw a single class in training, so there is one probability
            # column; report certainty according to which class that was.
            only_class = model.estimators_[i].classes_[0]
            probs.append(1.0 if only_class == 1 else 0.0)

    teammate_names = [col.replace("teammate_", "") for col in label_columns]
    ranked = sorted(zip(teammate_names, probs), key=lambda pair: pair[1], reverse=True)
    return pd.DataFrame(ranked[:top_n], columns=["Teammate", "Predicted Probability"])

In [8]:
# === Retrieve Sprite URL from PokeAPI ===

# Lower-cased, hyphenated display names whose PokeAPI resource name differs
# (PokeAPI expects an explicit form/mask suffix for these).
fallbacks = {
    "ogerpon-cornerstone": "ogerpon-cornerstone-mask",
    "ogerpon-hearthflame": "ogerpon-hearthflame-mask",
    "ogerpon-wellspring": "ogerpon-wellspring-mask",
    "ogerpon": "ogerpon-teal-mask",
    "landorus": "landorus-incarnate",
    "tornadus": "tornadus-incarnate",
    "thundurus": "thundurus-incarnate",
    "enamorus": "enamorus-incarnate",
    "urshifu": "urshifu-single-strike",
    "indeedee-f": "indeedee-female",
    "giratina": "giratina-altered",
    "necrozma-dawn-wings": "necrozma-dawn",
    "necrozma-dusk-mane": "necrozma-dusk",
}

# Successful lookups keyed by the name passed in, so repeated predictions do not
# re-query PokeAPI for the same Pokémon. Placeholder results are never cached,
# so a transient network failure is retried on the next call.
_sprite_url_cache = {}


def get_sprite_url(pokemon_name, timeout=10):
    """Return a sprite image URL for ``pokemon_name`` from PokeAPI.

    The name is normalised (lower-case, spaces -> hyphens, straight and curly
    apostrophes removed) and looked up, in order, as:

    1. its entry in ``fallbacks`` (if any),
    2. the normalised name itself,
    3. the part before the first hyphen (base species).

    The first lookup that yields a ``front_default`` sprite (or, failing that, the
    official artwork) wins. Each candidate is requested at most once.

    Parameters
    ----------
    pokemon_name : str
        Display name, e.g. ``"Calyrex-Shadow"``.
    timeout : float, default 10
        Seconds to wait for each PokeAPI request before moving to the next candidate.

    Returns
    -------
    str
        A sprite URL. If every attempt fails, a fixed placeholder URL on the PokeAPI
        sprites repository is returned (never ``None``).
    """
    if pokemon_name in _sprite_url_cache:
        return _sprite_url_cache[pokemon_name]

    base_name = pokemon_name.lower().replace(" ", "-").replace("’", "").replace("'", "")
    candidates = [fallbacks.get(base_name), base_name, base_name.split("-")[0]]

    # dict.fromkeys drops duplicates (e.g. a name with no hyphen) while keeping order.
    for attempt in dict.fromkeys(c for c in candidates if c):
        try:
            # Without a timeout, requests can wait indefinitely and freeze the button callback.
            res = requests.get(f"https://pokeapi.co/api/v2/pokemon/{attempt}", timeout=timeout)
            res.raise_for_status()
            sprites = res.json()["sprites"]
            sprite_url = (
                sprites["front_default"]
                or sprites["other"]["official-artwork"]["front_default"]
            )
        except (requests.RequestException, ValueError, KeyError, TypeError):
            # Best effort: network/HTTP errors, undecodable JSON, or missing sprite
            # fields just move on to the next candidate name.
            continue

        if sprite_url:
            _sprite_url_cache[pokemon_name] = sprite_url
            return sprite_url

    # NOTE(review): placeholder image path; confirm it exists and renders acceptably.
    return "https://raw.githubusercontent.com/PokeAPI/sprites/master/sprites/pokemon/0.png"

In [9]:
def on_predict_clicked(b):
    """Button callback: show sprites for the selected core and its predicted teammates.

    ``b`` is the clicked Button instance passed by ipywidgets; it is not used.
    Reads the current values of ``core1_dropdown`` / ``core2_dropdown`` and renders
    results with ``predict_teammates``, ``display_restricted_pair`` and
    ``show_sprite_grid`` (defined in later cells; they are looked up at click time).
    """
    # Drop the previous prediction (wait=True defers the clear until new output
    # arrives, avoiding flicker), then re-display the controls so they stay usable.
    clear_output(wait=True)
    display(core1_dropdown, core2_dropdown, predict_button)

    core_pair = (core1_dropdown.value, core2_dropdown.value)
    print(f"\nPredicting teammates for: {core_pair[0]} + {core_pair[1]}")

    results = predict_teammates(core_pair, model, X_df, label_columns)

    print("\nRestricted Core:")
    display_restricted_pair(core_pair)

    print("\nPredicted Teammates:")
    show_sprite_grid(results)


# Re-running this cell defines a new function object, and on_click would append it
# alongside the old one; clear existing handlers first so one click -> one prediction.
# NOTE: `_click_handlers` is a private ipywidgets attribute and may change between versions.
predict_button._click_handlers.callbacks.clear()
predict_button.on_click(on_predict_clicked)


In [10]:
from IPython.display import HTML

def display_restricted_pair(core_pair):
    urls = [get_sprite_url(name) for name in core_pair]
    html = "<div style='display:flex; gap:20px; align-items:center; margin-bottom:10px;'>"
    for url in urls:
        if url:
            html += f"<img src='{url}' style='height:128px;'>"
    html += "</div>"
    display(HTML(html))


In [11]:
# === FINAL OUTPUT DISPLAY ===

def show_sprite_grid(results, per_row=None):
    if per_row is None:
        per_row = ceil(len(results) ** 0.5)  # Make the grid square-ish

    html = "<table style='margin-top:10px'><tr>"
    count = 0
    filtered = [(name, prob) for name, prob in results.values if prob > 0]
    for name, prob in filtered:
        sprite_url = get_sprite_url(name)
        if not sprite_url:
            continue  # skip missing

        html += f"""
        <td style="text-align:center; padding:10px;">
            <img src="{sprite_url}" style="height:96px;"><br>
            <span>{name} - {prob:.1%}</span>
        </td>
        """
        count += 1
        if count % per_row == 0:
            html += "</tr><tr>"

    html += "</tr></table>"
    display(HTML(html))

> **Generative AI Disclaimer:**

Generative AI (Gemini, ChatGPT) was used to build the scaffolding for helper functions, such as the HTML formatting.