In [None]:
import json
import random
import numpy as np
from collections import defaultdict
from itertools import combinations

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
CARDS_JSON = "data/labeled_cards.json"
DECKS_JSON = "data/one_piece_deck_lists.json"
KNOWN_SYNERGIES_JSON = "graphs/card_synergies.json"  # new synergy file

with open(CARDS_JSON, "r") as f:
    cards_data = json.load(f)
with open(DECKS_JSON, "r") as f:
    decks_data = json.load(f)
with open(KNOWN_SYNERGIES_JSON, "r") as f:
    known_synergies_data = json.load(f)

print("Loaded", len(cards_data), "cards,",
      len(decks_data), "decks, and a known synergy file.")

BANNED_CARDS = {
    "ST10-001", "OP03-098", "OP05-041", "ST06-015",
    "OP06-116", "OP2-024", "OP02-052"
}

card_id_to_index = {c["id"]: i for i, c in enumerate(cards_data)}
card_index_to_id = {i: c["id"] for i, c in enumerate(cards_data)}

leaders = [c for c in cards_data if c["type"] == "Leader"]

In [None]:
def gather_all_traits(cards_data):
    traits_set = set()
    for c in cards_data:
        for t in c.get("traits", []):
            traits_set.add(t)
    return sorted(traits_set)

ALL_TRAITS = gather_all_traits(cards_data)
trait2idx = {t: i for i, t in enumerate(ALL_TRAITS)}
NUM_TRAITS = len(trait2idx)
print("Number of unique traits:", NUM_TRAITS)

def multi_hot_trait_vector(card):
    vec = np.zeros(NUM_TRAITS, dtype=np.float32)
    for t in card.get("traits", []):
        if t in trait2idx:
            vec[trait2idx[t]] = 1.0
    return vec

In [None]:
def get_card_colors(card):
    ccol = card.get("color", [])
    if isinstance(ccol, str):
        ccol = [ccol]
    return ccol

def find_leader_in_deck(deck):
    for entry in deck:
        cid = list(entry.keys())[0]
        if cid in [l["id"] for l in leaders]:
            return cid
    return None

def deck_archetype(deck):
    lid = find_leader_in_deck(deck)
    if not lid:
        return ""
    leader_card = next((l for l in leaders if l["id"] == lid), None)
    if leader_card:
        return leader_card.get("archetype","").strip()
    return ""

def deck_matches_color_archetype(deck, leader_color, leader_arche):
    lid = find_leader_in_deck(deck)
    if not lid:
        return False
    leader_card = next((c for c in cards_data if c["id"] == lid), None)
    if not leader_card:
        return False
    deck_colors = get_card_colors(leader_card)
    deck_arch = leader_card.get("archetype","").strip().lower()

    color_ok = any(c in deck_colors for c in leader_color)
    if not color_ok:
        return False
    if leader_arche:
        return (deck_arch == leader_arche.lower())
    return True

In [None]:
def parse_known_synergies(known_synergies_data, synergy_weight=5.0, treat_as_undirected=True):
    """
    Returns a dict of edges: synergy_edges[(source_id, target_id)] = synergy_weight
    known_synergies_data is the JSON object with "nodes" and "links".
    
    synergy_weight: base weight to add for each known synergy link
    treat_as_undirected: if True, we create edges both ways.
    """
    synergy_edges = defaultdict(float)
    # nodes array: we can ignore except for ID references
    # links array: we parse the synergy
    links = known_synergies_data.get("links", [])
    for ln in links:
        src = ln["source"]
        tgt = ln["target"]
        # label = ln["label"] # can use if you want different synergy per label

        if src in card_id_to_index and tgt in card_id_to_index:
            synergy_edges[(src, tgt)] += synergy_weight
            if treat_as_undirected:
                synergy_edges[(tgt, src)] += synergy_weight

    return synergy_edges


In [None]:
def gather_decks_for_color_archetype(leader_color, leader_arche, all_decks):
    filtered = [
        d for d in all_decks
        if deck_matches_color_archetype(d, leader_color, leader_arche)
    ]
    if filtered:
        return filtered
    
    # fallback: color only
    def deck_matches_color(deck, chosen_color):
        lid = find_leader_in_deck(deck)
        if not lid:
            return False
        lead_card = next((c for c in cards_data if c["id"]==lid), None)
        if not lead_card:
            return False
        lead_colors = get_card_colors(lead_card)
        return any(x in lead_colors for x in chosen_color)

    fallback = [d for d in all_decks if deck_matches_color(d, leader_color)]
    return fallback

def build_synergy_graph(
    relevant_decks, 
    leader_arche, 
    synergy_weight=2.0, 
    known_synergy_edges=None
):
    """
    1) Build an undirected synergy graph from deck co-occurrence.
    2) If deck archetype == leader_arche, edges get synergy_weight multiplier.
    3) Then incorporate known synergy edges if provided.
    """
    G = nx.Graph()
    all_ids = {c["id"] for c in cards_data}
    G.add_nodes_from(all_ids)

    for deck in relevant_decks:
        lid = find_leader_in_deck(deck)
        if not lid:
            continue
        deck_arch = deck_archetype(deck)
        w_factor = synergy_weight if deck_arch.lower() == leader_arche.lower() else 1.0

        deck_cards = []
        for entry in deck:
            cid = list(entry.keys())[0]
            cnt = entry[cid]
            if cid == lid:
                continue
            if cid in all_ids and cnt>0:
                deck_cards.extend([cid]*cnt)

        # add edges from co-occurrence
        for a, b in combinations(deck_cards, 2):
            if a == b:
                continue
            if G.has_edge(a, b):
                G[a][b]["weight"] += w_factor
            else:
                G.add_edge(a, b, weight=w_factor)

    # incorporate known synergy edges
    if known_synergy_edges:
        for (src, tgt), w in known_synergy_edges.items():
            # skip if banned or not in dataset
            if src not in all_ids or tgt not in all_ids:
                continue
            if src in BANNED_CARDS or tgt in BANNED_CARDS:
                continue
            if G.has_edge(src, tgt):
                G[src][tgt]["weight"] += w
            else:
                G.add_edge(src, tgt, weight=w)

    return G


In [None]:
def compute_card_frequency(relevant_decks):
    freq_map = defaultdict(int)
    all_ids = {c["id"] for c in cards_data}
    for d in relevant_decks:
        lid = find_leader_in_deck(d)
        for entry in d:
            cid = list(entry.keys())[0]
            cnt = entry[cid]
            if cid in all_ids and cid!= lid:
                freq_map[cid]+= cnt
    return freq_map

def compute_archetype_usage_ratio(relevant_decks, leader_arche):
    from collections import defaultdict
    all_ids= {c["id"] for c in cards_data}
    arche_decks= [d for d in relevant_decks if deck_archetype(d).lower()== leader_arche.lower()]
    deck_count= len(arche_decks)
    usage_count= defaultdict(int)
    if deck_count==0:
        return {c["id"]: 0.0 for c in cards_data}

    for d in arche_decks:
        cset= set()
        lid = find_leader_in_deck(d)
        for entry in d:
            cid = list(entry.keys())[0]
            cnt = entry[cid]
            if cid!= lid and cnt>0 and cid in all_ids:
                cset.add(cid)
        for cid in cset:
            usage_count[cid]+=1

    ratio_map = {}
    for c in cards_data:
        cid = c["id"]
        ratio = usage_count[cid]/deck_count if deck_count>0 else 0.0
        ratio_map[cid] = ratio
    return ratio_map

In [None]:
def build_node_feature_with_traits(card, archetype_usage_map):
    cost_val = float(card.get("cost", 0.0))
    power_val= float(card.get("power", 0.0))
    ctr_val  = float(card.get("counter", 0.0))
    arch_use = archetype_usage_map.get(card["id"], 0.0)

    trait_vec = multi_hot_trait_vector(card)
    base_features = np.array([cost_val, power_val, ctr_val, arch_use], dtype=np.float32)
    return np.concatenate((base_features, trait_vec))

def build_pyg_data(cards_data, synergy_graph, freq_map, archetype_usage_map):
    edges = []
    weights = []
    node_features = []
    usage_list = []

    for c in cards_data:
        cid = c["id"]
        feat = build_node_feature_with_traits(c, archetype_usage_map)
        node_features.append(feat)
        usage_list.append(freq_map.get(cid, 0))

    x_tensor = torch.tensor(node_features, dtype=torch.float)
    node_labels = np.array(usage_list, dtype=np.float32)

    for u, v, data_attr in synergy_graph.edges(data=True):
        if u in card_id_to_index and v in card_id_to_index:
            ui = card_id_to_index[u]
            vi = card_id_to_index[v]
            w  = data_attr["weight"]
            edges.append([ui, vi])
            edges.append([vi, ui])
            weights.append(w)
            weights.append(w)

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_weight= torch.tensor(weights, dtype=torch.float)
    else:
        edge_index = torch.empty((2,0), dtype=torch.long)
        edge_weight= torch.empty((0,), dtype=torch.float)

    data_pyg = Data(x=x_tensor, edge_index=edge_index)
    data_pyg.edge_weight= edge_weight
    return data_pyg, node_labels

In [None]:
class MultiTaskGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64, emb_dim=32, dropout=0.3, num_layers=3):
        super().__init__()
        self.num_layers= num_layers
        self.dropout= torch.nn.Dropout(dropout)

        self.conv1= GCNConv(in_dim, hidden_dim)
        self.conv2= GCNConv(hidden_dim, hidden_dim)
        self.conv3= GCNConv(hidden_dim, emb_dim)

        self.node_head= torch.nn.Linear(emb_dim, 1)
        self.synergy_mlp= torch.nn.Sequential(
            torch.nn.Linear(2*emb_dim, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

    def forward(self, data):
        x, edge_index= data.x, data.edge_index
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.dropout(h)

        h = self.conv2(h, edge_index)
        h = F.relu(h)
        h = self.dropout(h)

        h = self.conv3(h, edge_index)
        h = F.relu(h)
        h = self.dropout(h)

        node_score = self.node_head(h).squeeze(-1)
        return h, node_score

    def predict_synergy(self, hu, hv):
        inp = torch.cat([hu, hv], dim=-1)
        return self.synergy_mlp(inp).squeeze(-1)

In [None]:
def sample_negative_edges(edge_index, num_nodes, num_neg):
    existing= set()
    E= edge_index.shape[1]
    for i in range(E):
        u= edge_index[0, i].item()
        v= edge_index[1, i].item()
        existing.add((u,v))

    neg= []
    attempts= 0
    while len(neg)< num_neg and attempts<2_000_000:
        attempts+=1
        a= random.randint(0, num_nodes-1)
        b= random.randint(0, num_nodes-1)
        if a==b:
            continue
        if (a,b) in existing:
            continue
        neg.append((a,b))
    return neg

def train_gnn(model, data, node_labels_np, epochs=100, lr=0.01, neg_factor=2.0):
    device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model= model.to(device)
    data= data.to(device)

    optimizer= torch.optim.Adam(model.parameters(), lr=lr)
    mean_val= node_labels_np.mean()
    std_val = node_labels_np.std() if node_labels_np.std()>0 else 1.0
    node_labels_norm= (node_labels_np- mean_val)/ std_val
    node_labels_t= torch.tensor(node_labels_norm, dtype=torch.float, device=device)

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        node_emb, node_score= model(data)
        loss_node= F.mse_loss(node_score, node_labels_t)

        # synergy
        pos_u, pos_v, pos_w= [], [], []
        E= data.edge_index.shape[1]
        for i in range(0, E, 2):
            u= data.edge_index[0, i].item()
            v= data.edge_index[1, i].item()
            w= data.edge_weight[i].item()
            pos_u.append(u)
            pos_v.append(v)
            pos_w.append(np.log1p(w))

        loss_synergy= 0.0
        if len(pos_u)>0:
            pos_u_t= torch.tensor(pos_u, dtype=torch.long, device=device)
            pos_v_t= torch.tensor(pos_v, dtype=torch.long, device=device)
            pos_w_t= torch.tensor(pos_w, dtype=torch.float, device=device)

            num_pos= len(pos_u)
            num_neg= int(num_pos* neg_factor)
            neg_ed= sample_negative_edges(data.edge_index, data.x.shape[0], num_neg)
            if neg_ed:
                neg_u= [p[0] for p in neg_ed]
                neg_v= [p[1] for p in neg_ed]
                neg_w= [0.0]* len(neg_ed)
                neg_u_t= torch.tensor(neg_u, dtype=torch.long, device=device)
                neg_v_t= torch.tensor(neg_v, dtype=torch.long, device=device)
                neg_w_t= torch.tensor(neg_w, dtype=torch.float, device=device)

                all_u= torch.cat([pos_u_t, neg_u_t], dim=0)
                all_v= torch.cat([pos_v_t, neg_v_t], dim=0)
                all_w= torch.cat([pos_w_t, neg_w_t], dim=0)

                hu= node_emb[all_u]
                hv= node_emb[all_v]
                synergy_pred= model.predict_synergy(hu, hv)
                loss_synergy= F.mse_loss(synergy_pred, all_w)
            else:
                hu= node_emb[pos_u_t]
                hv= node_emb[pos_v_t]
                synergy_pred= model.predict_synergy(hu, hv)
                loss_synergy= F.mse_loss(synergy_pred, pos_w_t)

        loss= loss_node + loss_synergy
        loss.backward()
        optimizer.step()

        if (epoch+1)%10==0:
            print(f"Epoch {epoch+1}/{epochs} "
                  f"NodeLoss={loss_node.item():.4f} "
                  f"SynergyLoss={loss_synergy:.4f} "
                  f"Combined={loss.item():.4f}")

    return model, mean_val, std_val

In [None]:
def build_synergy_matrix(model, data, freq_map, node_labels_np, mean_val, std_val):
    device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    data= data.to(device)
    with torch.no_grad():
        node_emb, node_score_t= model(data)
    node_emb= node_emb.cpu()
    node_score= node_score_t.cpu().numpy()

    node_score= node_score* std_val + mean_val

    num_cards= len(cards_data)
    synergy_matrix= np.zeros((num_cards, num_cards), dtype=np.float32)

    def synergy_value(i, j):
        hi= node_emb[i]
        hj= node_emb[j]
        val= model.predict_synergy(
            hi.unsqueeze(0).to(device),
            hj.unsqueeze(0).to(device)
        ).item()
        return val

    for i in range(num_cards):
        for j in range(i+1, num_cards):
            sv= synergy_value(i,j)
            synergy_matrix[i,j]= sv
            synergy_matrix[j,i]= sv

    # diagonal => node usage
    for i in range(num_cards):
        synergy_matrix[i,i]= max(0.1, node_score[i])
        if freq_map[cards_data[i]["id"]]==0:
            node_score[i]-=2.0

    return synergy_matrix, node_score

def analyze_label_usage_for_decks(relevant_decks):
    from collections import defaultdict
    label_count= defaultdict(float)
    deck_count= 0
    all_ids= {c["id"] for c in cards_data}

    for d in relevant_decks:
        lid= find_leader_in_deck(d)
        if not lid:
            continue
        deck_count+=1
        cset= set()
        for entry in d:
            cid = list(entry.keys())[0]
            cnt= entry[cid]
            if cid in all_ids and cid!= lid and cnt>0:
                cset.add(cid)
        for cid in cset:
            labs= cards_data[card_id_to_index[cid]].get("labels", [])
            for la in labs:
                label_count[la]+=1

    if deck_count>0:
        for la in label_count:
            label_count[la]/= deck_count
    return dict(label_count)

def incorporate_label_synergy(cards_data, synergy_matrix, label_usage, factor=0.1):
    n= len(cards_data)
    for i in range(n):
        labs_i= set(cards_data[i].get("labels", []))
        for j in range(i+1, n):
            labs_j= set(cards_data[j].get("labels", []))
            shared= labs_i.intersection(labs_j)
            bonus=0.0
            for lab in shared:
                if lab in label_usage:
                    bonus+= label_usage[lab]* factor
            synergy_matrix[i,j]+= bonus
            synergy_matrix[j,i]+= bonus

def analyze_cost_distribution(relevant_decks):
    from collections import defaultdict
    cost_map= defaultdict(float)
    deck_count=0
    all_ids= {c["id"] for c in cards_data}
    for d in relevant_decks:
        lid= find_leader_in_deck(d)
        if not lid:
            continue
        deck_count+=1
        for entry in d:
            cid= list(entry.keys())[0]
            cnt= entry[cid]
            if cid in all_ids and cid!= lid and cnt>0:
                cst= int(cards_data[card_id_to_index[cid]].get("cost", 0))
                cost_map[cst]+= cnt
    if deck_count>0:
        for cst in cost_map:
            cost_map[cst]/= deck_count
    return dict(cost_map)

def analyze_type_usage(relevant_decks):
    from collections import defaultdict
    tmap= defaultdict(float)
    deck_count=0
    all_ids= {c["id"] for c in cards_data}
    for d in relevant_decks:
        lid= find_leader_in_deck(d)
        if not lid:
            continue
        deck_count+=1
        for entry in d:
            cid= list(entry.keys())[0]
            cnt= entry[cid]
            if cid in all_ids and cid!= lid and cnt>0:
                t = cards_data[card_id_to_index[cid]].get("type","Character")
                tmap[t]+= cnt
    if deck_count>0:
        for t in tmap:
            tmap[t]/= deck_count
    return dict(tmap)

In [None]:
def cost_curve_penalty(deck, cards_data, meta_cost):
    from collections import defaultdict
    cdist= defaultdict(int)
    for i, cpy in enumerate(deck):
        if cpy>0:
            cst= int(cards_data[i].get("cost", 0))
            cdist[cst]+= cpy
    penalty= 0.0
    for cst, avg_val in meta_cost.items():
        diff= abs(cdist[cst]- avg_val)
        penalty-= diff*0.5
    return penalty

def type_usage_bonus(deck, cards_data, meta_type):
    from collections import defaultdict
    tdist= defaultdict(int)
    for i, cpy in enumerate(deck):
        if cpy>0:
            tp= cards_data[i].get("type","Character")
            tdist[tp]+= cpy
    bonus=0.0
    for t, avg_val in meta_type.items():
        diff= abs(tdist[t]- avg_val)
        bonus-= diff*0.3
    return bonus

def final_deck_fitness(deck,
                       node_score, synergy_matrix,
                       archetype_usage_array,
                       cards_data,
                       meta_cost, meta_type,
                       alpha=1.0,
                       beta=3.0,
                       archetype_weight=1.0,
                       singleton_penalty_factor=2.0):
    synergy_val=0.0
    node_val=0.0
    cohesion_val=0.0
    N= len(deck)
    for i in range(N):
        ci= deck[i]
        if ci>0:
            node_val+= node_score[i]* ci
            synergy_val+= synergy_matrix[i,i]* (ci*(ci-1)/2)
            for j in range(i+1, N):
                cj= deck[j]
                if cj>0:
                    synergy_val+= synergy_matrix[i,j]* (ci*cj)
            cohesion_val+= archetype_usage_array[i]* ci

    base= alpha* node_val + beta* synergy_val
    ccurve= cost_curve_penalty(deck, cards_data, meta_cost)
    tbonus= type_usage_bonus(deck, cards_data, meta_type)

    # stronger singletons penalty
    single_pen= 0.0
    for i, cpy in enumerate(deck):
        if cpy==1:
            single_pen-= synergy_matrix[i,i]* singleton_penalty_factor

    return base + ccurve + tbonus + single_pen + (cohesion_val* archetype_weight)

In [None]:
def valid_card_for_leader(card, leader_colors):
    if card["id"] in BANNED_CARDS:
        return False
    if card["type"]=="Leader":
        return False
    cclr= get_card_colors(card)
    return any(x in cclr for x in leader_colors)

def build_valid_indices(cards_data, leader_colors):
    v=[]
    for i, c in enumerate(cards_data):
        if valid_card_for_leader(c, leader_colors):
            v.append(i)
    return v

def random_deck(valid_idxs, total_len, target=50):
    d= [0]* total_len
    s=0
    while s< target:
        pick= random.choice(valid_idxs)
        if d[pick]<4:
            d[pick]+=1
            s+=1
    return d

def repair_deck(deck, valid_idxs, target=50):
    newd= deck[:]
    s= sum(newd)
    for i in range(len(newd)):
        if i not in valid_idxs:
            s-= newd[i]
            newd[i]=0
    while s> target:
        cands= [ix for ix in valid_idxs if newd[ix]>0]
        if not cands:
            break
        rm= random.choice(cands)
        newd[rm]-=1
        s-=1
    while s< target:
        cands= [ix for ix in valid_idxs if newd[ix]<4]
        if not cands:
            break
        ad= random.choice(cands)
        newd[ad]+=1
        s+=1
    return newd

def tournament_selection(pop, fits, k=3):
    best_idx= None
    best_fit= None
    for _ in range(k):
        r= random.randrange(len(pop))
        if best_fit is None or fits[r]> best_fit:
            best_fit= fits[r]
            best_idx= r
    return pop[best_idx], best_fit

def crossover(a, b):
    c=[0]* len(a)
    for i in range(len(a)):
        c[i] = a[i] if random.random()<0.5 else b[i]
    return c

def mutate(deck, valid_idxs, rate=0.2):
    d= deck[:]
    for i in range(len(d)):
        if random.random()< rate:
            if i in valid_idxs:
                if d[i]==0:
                    d[i]=1
                elif d[i]<4:
                    if random.random()<0.5:
                        d[i]= min(4, d[i]+1)
                    else:
                        d[i]= max(0, d[i]-1)
                else:
                    d[i]=3
            else:
                d[i]=0
    return d

def synergy_local_search(deck, valid_idxs, fitnessFn, synergy_matrix, iterations=30):
    best= deck[:]
    best_fit= fitnessFn(best)

    synergy_thresh= 2.0
    max_expansions= 6

    for _ in range(iterations):
        # random tweak
        newd= best[:]
        choice= random.random()
        s= sum(newd)
        if choice<0.4:
            if s<50:
                cands= [ix for ix in valid_idxs if newd[ix]<4]
                if cands:
                    pick= random.choice(cands)
                    newd[pick]+=1
        elif choice<0.8:
            cands= [ix for ix in valid_idxs if newd[ix]>0]
            if cands:
                pick= random.choice(cands)
                newd[pick]-=1
        else:
            cands_rm= [ix for ix in valid_idxs if newd[ix]>0]
            cands_ad= [ix for ix in valid_idxs if newd[ix]<4]
            if cands_rm and cands_ad:
                rr= random.choice(cands_rm)
                aa= random.choice(cands_ad)
                newd[rr]-=1
                newd[aa]+=1

        newd= repair_deck(newd, valid_idxs, 50)
        newf= fitnessFn(newd)
        if newf> best_fit:
            best= newd
            best_fit= newf

        # synergy expansions
        expansions=0
        synergy_pairs=[]
        N= len(newd)
        for i in range(N):
            for j in range(i+1, N):
                val= synergy_matrix[i,j]
                if val> synergy_thresh:
                    synergy_pairs.append((val,i,j))
        synergy_pairs.sort(key=lambda x: x[0], reverse=True)

        for val, i2, j2 in synergy_pairs:
            if expansions>= max_expansions:
                break
            while sum(best)<50 and best[i2]<4 and best[j2]<4:
                best[i2]+=1
                best[j2]+=1
                best= repair_deck(best, valid_idxs, 50)
                newf2= fitnessFn(best)
                if newf2> best_fit:
                    best_fit= newf2
                else:
                    best[i2]-=1
                    best[j2]-=1
                    break
            expansions+=1

    return best, best_fit

In [None]:
AGGRO_CORE = [
    # example staple IDs for an Aggro skeleton
    # "OP01-006",  # Otama
    # "OP01-016",  # Nami
    # "ST01-006",  # Tony Tony Chopper
]

def seed_archetype_skeleton(valid_idxs):
    deck= [0]* len(cards_data)
    ccount=0
    for cid in AGGRO_CORE:
        if cid in card_id_to_index:
            idx= card_id_to_index[cid]
            if idx in valid_idxs:
                deck[idx]=4
                ccount+=4
    return deck, ccount

In [None]:
def build_deck(leader_id="OP01-001",
               pop_size=50,
               generations=100,
               synergy_weight=2.0,
               known_synergy_weight=5.0,
               alpha=1.0,
               beta=3.0,
               archetype_weight=1.0,
               singleton_penalty_factor=2.0,
               use_skeleton=True):
    # find leader
    leader_card= next((c for c in cards_data if c["id"]== leader_id), None)
    if not leader_card:
        raise ValueError(f"Leader {leader_id} not found!")
    leader_color= get_card_colors(leader_card)
    leader_arche= leader_card.get("archetype","").strip()

    print(f"Building deck for Leader {leader_id} (archetype={leader_arche}, color={leader_color})")

    # parse known synergy edges
    known_synergy_edges= parse_known_synergies(
        known_synergies_data,
        synergy_weight= known_synergy_weight,
        treat_as_undirected=True
    )

    # gather relevant decks
    relevant_decks= gather_decks_for_color_archetype(leader_color, leader_arche, decks_data)
    if not relevant_decks:
        print("No relevant decks found, fallback to all decks.")
        relevant_decks= decks_data

    # synergy graph
    synergy_graph= build_synergy_graph(
        relevant_decks,
        leader_arche,
        synergy_weight= synergy_weight,
        known_synergy_edges= known_synergy_edges
    )
    freq_map= compute_card_frequency(relevant_decks)
    archetype_usage_map= compute_archetype_usage_ratio(relevant_decks, leader_arche)

    # build PyG data
    data_pyg, node_labels_np= build_pyg_data(cards_data, synergy_graph, freq_map, archetype_usage_map)

    # create GNN
    in_dim= 4 + NUM_TRAITS
    model= MultiTaskGNN(in_dim=in_dim, hidden_dim=64, emb_dim=32, dropout=0.3, num_layers=3)
    model, mean_val, std_val= train_gnn(model, data_pyg, node_labels_np, epochs=100, lr=0.01, neg_factor=2.0)

    synergy_matrix, node_score= build_synergy_matrix(model, data_pyg, freq_map, node_labels_np, mean_val, std_val)

    # optional label synergy
    label_usage= analyze_label_usage_for_decks(relevant_decks)
    incorporate_label_synergy(cards_data, synergy_matrix, label_usage, factor=0.1)

    meta_cost= analyze_cost_distribution(relevant_decks)
    meta_type= analyze_type_usage(relevant_decks)

    # build valid idx
    N= len(cards_data)
    valid_idxs= []
    for i, card in enumerate(cards_data):
        if card["id"] not in BANNED_CARDS and card["type"]!="Leader":
            cclr= get_card_colors(card)
            if any(x in cclr for x in leader_color):
                valid_idxs.append(i)

    def fitnessFn(deck):
        arch_arr= [archetype_usage_map[cards_data[i]["id"]] for i in range(N)]
        return final_deck_fitness(
            deck= deck,
            node_score= node_score,
            synergy_matrix= synergy_matrix,
            archetype_usage_array= arch_arr,
            cards_data= cards_data,
            meta_cost= meta_cost,
            meta_type= meta_type,
            alpha= alpha,
            beta= beta,
            archetype_weight= archetype_weight,
            singleton_penalty_factor= singleton_penalty_factor
        )

    # GA initialization
    population= []
    half_pop= pop_size//2
    # skeleton
    if use_skeleton and leader_arche.lower()=="aggro":
        for _ in range(half_pop):
            base, ccount= seed_archetype_skeleton(valid_idxs)
            while ccount<50:
                pick= random.choice(valid_idxs)
                if base[pick]<4:
                    base[pick]+=1
                    ccount+=1
            population.append(base)
        for _ in range(pop_size- half_pop):
            population.append(random_deck(valid_idxs, N, 50))
    else:
        for _ in range(pop_size):
            population.append(random_deck(valid_idxs, N, 50))

    best_global= None
    best_global_fit= None

    for gen in range(1, generations+1):
        fits= [fitnessFn(d) for d in population]
        best_idx= np.argmax(fits)
        best_fit= fits[best_idx]
        if best_global_fit is None or best_fit> best_global_fit:
            best_global_fit= best_fit
            best_global= population[best_idx][:]

        # evolve
        new_pop= []
        new_pop.append(best_global)
        while len(new_pop)< pop_size:
            pA,_= tournament_selection(population, fits, k=3)
            pB,_= tournament_selection(population, fits, k=3)
            c= crossover(pA, pB)
            c= mutate(c, valid_idxs, rate=0.2)
            c= repair_deck(c, valid_idxs, 50)
            new_pop.append(c)
        population= new_pop

        if gen%10==0:
            print(f"Gen {gen}/{generations} => best in pop= {best_fit:.2f}, global= {best_global_fit:.2f}")

    final_deck, final_fit= synergy_local_search(best_global, valid_idxs, fitnessFn, synergy_matrix, 30)
    if final_fit> best_global_fit:
        best_global_fit= final_fit
        best_global= final_deck

    # final
    print("\n=== FINAL DECK ===")
    print(f"Leader: {leader_card['id']}  Archetype: {leader_arche}  Color(s): {leader_color}")
    print(f"Final Score = {best_global_fit:.2f}")
    csum=0
    for i, cpy in enumerate(best_global):
        if cpy>0:
            cid= cards_data[i]["id"]
            name= cards_data[i]["name"]
            print(f"{cpy}x {cid} ({name})")
            csum+= cpy
    print("Total non-leader cards:", csum)
    return best_global, best_global_fit


Loaded 1301 cards, 1732 decks, and a known synergy file.
Number of unique traits: 168


In [26]:
build_deck(
    leader_id="OP01-001",
    pop_size=50,
    generations=200,
    synergy_weight=2.0,       # synergy edges from matching archetype decks
    known_synergy_weight=5.0, # synergy edges from known combos
    alpha=1.0,
    beta=3.0,
    archetype_weight=1.0,
    singleton_penalty_factor=2.0,
    use_skeleton=True
)

Building deck for Leader OP01-001 (archetype=Aggro, color=['Red'])
Epoch 10/100 NodeLoss=301.7884 SynergyLoss=27.1107 Combined=328.8991
Epoch 20/100 NodeLoss=37.6258 SynergyLoss=4.4945 Combined=42.1204
Epoch 30/100 NodeLoss=4.2777 SynergyLoss=2.6202 Combined=6.8978
Epoch 40/100 NodeLoss=1.4828 SynergyLoss=2.0933 Combined=3.5761
Epoch 50/100 NodeLoss=2.2529 SynergyLoss=1.6593 Combined=3.9122
Epoch 60/100 NodeLoss=2.3556 SynergyLoss=0.9936 Combined=3.3492
Epoch 70/100 NodeLoss=3.8357 SynergyLoss=1.0875 Combined=4.9232
Epoch 80/100 NodeLoss=1.3941 SynergyLoss=0.8972 Combined=2.2913
Epoch 90/100 NodeLoss=1.6850 SynergyLoss=0.8229 Combined=2.5080
Epoch 100/100 NodeLoss=1.1464 SynergyLoss=0.8197 Combined=1.9661
Gen 10/200 => best in pop= 7694.47, global= 7694.47
Gen 20/200 => best in pop= 7964.90, global= 7964.90
Gen 30/200 => best in pop= 7964.90, global= 7964.90
Gen 40/200 => best in pop= 7964.90, global= 7964.90
Gen 50/200 => best in pop= 8177.25, global= 8177.25
Gen 60/200 => best in pop

([0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  2,
  0,
  1,
  0,
  1,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  2,
  1,
  0,
  0,
  2,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,


In [None]:
# print heuristics of leader
leader_id = "OP01-001"

print("Leader:", leader_id)
leader_card = next((c for c in cards_data if c["id"] == leader_id), None)
leader_color = get_card_colors(leader_card)
leader_arche = leader_card.get("archetype", "").strip()

print("Leader Color:", leader_color)
print("Leader Archetype:", leader_arche)

# get cost distribution
relevant_decks = gather_decks_for_color_archetype(leader_color, leader_arche, decks_data)
cost_dist = analyze_cost_distribution(relevant_decks)
print("Cost Distribution:", cost_dist)

# get type usage
type_dist = analyze_type_usage(relevant_decks)
print("Type Usage:", type_dist)

# get label usage
label_usage = analyze_label_usage_for_decks(relevant_decks)
print("Label Usage:", label_usage)


Leader: OP01-001
Leader Color: ['Red']
Leader Archetype: Aggro
Cost Distribution: {1: 19.177570093457945, 3: 7.373831775700935, 5: 4.621495327102804, 2: 6.901869158878505, 4: 9.80373831775701, 9: 0.9719626168224299, 7: 0.3364485981308411, 6: 0.14485981308411214, 10: 0.1822429906542056, 0: 0.3598130841121495, 8: 0.028037383177570093}
Type Usage: {'Character': 42.22429906542056, 'Event': 7.668224299065421, 'Stage': 0.009345794392523364}
Label Usage: {'Vanilla': 0.43457943925233644, 'Low Cost': 10.542056074766355, 'Debuff Power': 2.7383177570093458, 'Mid Cost': 4.635514018691588, 'Rush': 2.0, 'Blocker': 2.308411214953271, 'Counter': 1.3878504672897196, 'Trigger': 3.350467289719626, 'Buff Power': 2.6448598130841123, 'Trash Interaction': 0.3364485981308411, 'Removal': 3.135514018691589, 'Searcher': 1.1869158878504673, 'High Cost': 0.5934579439252337, 'Draw': 0.677570093457944, 'Protection': 0.02336448598130841, 'Discard': 1.294392523364486, 'Summon': 1.0420560747663552, 'Leader Locked': 0.2

In [1]:
def plot_deck_synergy_graph(synergy_matrix, cards_data, threshold=2.0):
    # use cards as nodes with id and name as labels
    G = nx.Graph()
    for c in cards_data:
        G.add_node(c["id"], name=c["name"])
    
    # add edges with synergy weight as label
    N = len(cards_data)
    for i in range(N):
        for j in range(i+1, N):
            if synergy_matrix[i, j] > threshold:
                G.add_edge(cards_data[i]["id"], cards_data[j]["id"], weight=synergy_matrix[i, j])
            
    # draw
    pos = nx.spring_layout(G, seed=42)
    edge_labels = nx.get_edge_attributes(G, 'weight')
    plt.figure(figsize=(12, 12))
    nx.draw(G, pos, with_labels=True, node_size=3000, font_size=10, font_color="black")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red')
    plt.title("Card Synergy Graph")
    plt.show()
    

In [2]:
# check performance of one of the relevant decks
deck = relevant_decks[0]
lid = find_leader_in_deck(deck)

plot_deck_synergy_graph(synergy_matrix, cards_data, threshold=2.0)

NameError: name 'relevant_decks' is not defined