In [1]:
# Optional one-time setup: %pip install flash_attn  (Flash Attention 2 kernels for the embedding model loaded later)

In [2]:
import polars as pl
import orjson
import torch
# import yaml
from tqdm import tqdm

import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

In [3]:
# Read the card table, then draw every row (fraction=1) in shuffled order.
# The fixed seed makes the shuffled order reproducible across runs.
df = (
    pl.read_parquet("mtg_data_2.parquet")
    .sample(fraction=1, shuffle=True, seed=42)
)

df

name,scryfallId,manaCost,type,text,power,toughness,loyalty,rarities,sets
str,str,str,str,str,str,str,str,list[enum],list[enum]
"""Vilis, Broker of Blood""","""d68ce39b-0847-4132-b3b5-8bc25f…","""{5}{B}{B}{B}""","""Legendary Creature — Demon""","""Flying\n{B}, Pay 2 life: Targe…","""8""","""8""",,"[""rare""]","[""PM20"", ""M20"", … ""J25""]"
"""Cruel Celebrant""","""b3f8be99-a398-4664-beb6-3e98d6…","""{W}{B}""","""Creature — Vampire""","""Whenever ~ or another creature…","""1""","""2""",,"[""uncommon""]","[""WAR"", ""PLST"", ""LCC""]"
"""Leyline of Lifeforce""","""f7caffa7-29bd-455c-9770-94a0ad…","""{2}{G}{G}""","""Enchantment""","""If ~ is in your opening hand, …",,,,"[""rare""]","[""GPT""]"
"""Lumbering Falls""","""e3c4e109-2aa7-4f5f-8b59-7b89c0…",,"""Land""","""~ enters tapped.\n{T}: Add {G}…",,,,"[""rare""]","[""PBFZ"", ""BFZ"", … ""PIO""]"
"""Grand Warlord Radha""","""986981fa-a744-45d8-81c7-68ef33…","""{2}{R}{G}""","""Legendary Creature — Elf Warri…","""Haste\nWhenever one or more cr…","""3""","""4""",,"[""rare""]","[""PDOM"", ""DOM""]"
…,…,…,…,…,…,…,…,…,…
"""Mimeofacture""","""f4ada33a-5b7c-426a-8416-4ce01b…","""{3}{U}""","""Sorcery""","""Replicate {3}{U} (When you cas…",,,,"[""rare""]","[""GPT""]"
"""Scar""","""b34e3f7c-468a-456c-8ed0-0cb88f…","""{B/R}""","""Instant""","""Put a -1/-1 counter on target …",,,,"[""common""]","[""SHM""]"
"""Field Marshal""","""0b81e16f-8e5c-42e2-9d4e-220eb3…","""{1}{W}{W}""","""Creature — Human Soldier""","""Other Soldier creatures get +1…","""2""","""2""",,"[""rare""]","[""CSP"", ""10E"", ""SLD""]"
"""Flanking Licid""","""21bda5a9-4faf-4f7c-a04b-209346…","""{1}{R}""","""Summon Licid""","""{R}, {T}: ~ loses this ability…",,,,"[""rare""]","[""MB2""]"


In [4]:
def _card_to_json(card):
    """Serialize one card row (a dict) as 2-space-indented JSON text.

    Fields whose value is None are omitted, as is the "scryfallId" field.
    """
    kept = {
        field: value
        for field, value in card.items()
        if not (value is None or field == "scryfallId")
    }
    return orjson.dumps(kept, option=orjson.OPT_INDENT_2).decode("utf-8")


# One JSON document per card, in the same order as the rows of `df`.
docs = [_card_to_json(card) for card in df.iter_rows(named=True)]

print(docs[0])

{
  "name": "Vilis, Broker of Blood",
  "manaCost": "{5}{B}{B}{B}",
  "type": "Legendary Creature — Demon",
  "text": "Flying\\n{B}, Pay 2 life: Target creature gets -1/-1 until end of turn.\\nWhenever you lose life, draw that many cards. (Damage causes loss of life.)",
  "power": "8",
  "toughness": "8",
  "rarities": [
    "rare"
  ],
  "sets": [
    "PM20",
    "M20",
    "SLD",
    "GN3",
    "J25"
  ]
}


In [5]:
# Hugging Face checkpoint used to embed the documents, and the device it runs on.
model_path = "Alibaba-NLP/gte-modernbert-base"
device = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(model_path)
# The model is instantiated first and only then moved to `device`; `.to()` returns
# the same module, so the load and the move can be chained.
model = AutoModel.from_pretrained(model_path).to(device)

torch.set_float32_matmul_precision("high")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [6]:
# Smoke test: tokenize only the first document with the same settings used for
# the full run, then move the resulting tensors to the model's device.
tokenized_docs = tokenizer(
    docs[0],
    max_length=8192,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
tokenized_docs = tokenized_docs.to(device)

tokenized_docs

{'input_ids': tensor([[50281,    92,   187, 50276,     3,  1590,  1381,   346,    55, 27154,
            13,  4819,  6426,   273, 14169,   995,   187, 50276,     3,  1342,
            66, 25997,  1381, 36028,    22,  1217,    35,  1217,    35,  1217,
            35, 32722,   187, 50276,     3,   881,  1381,   346, 18596,   423,
           552, 13489,   459,  1905,  4281,   251,   995,   187, 50276,     3,
          1156,  1381,   346,    39,  2943,  3353,    79,    92,    35,  2023,
         12286,   374,  1495,    27, 17661, 15906,  4850,   428,    18,  7448,
            18,  1919,   990,   273,  1614,    15,  3353,    79, 43835,   368,
          7168,  1495,    13,  3812,   326,  1142,  8364,    15,   313, 21727,
           486,  5997,  2957,   273,  1495,  2698,   995,   187, 50276,     3,
          9177,  1381,   346,    25,   995,   187, 50276,     3,    85,   602,
          1255,  1381,   346,    25,   995,   187, 50276,     3, 23537,  1005,
          1381,   544,   187, 50274,  

In [7]:
# Batch the raw JSON strings. `docs` is a plain list of Python strings, so the
# loader yields lists of strings and has no tensors to pin; the former
# pin_memory / pin_memory_device options had no effect here and are omitted.
# shuffle=False keeps the embeddings in the same order as `docs`.
dataloader = torch.utils.data.DataLoader(docs, batch_size=64, shuffle=False)

# Per-batch CPU tensors, concatenated once after the loop.
embedding_batches = []
for batch in tqdm(dataloader, smoothing=0):
    # padding=True pads to the longest document in the batch; anything longer
    # than max_length tokens is truncated.
    tokenized_batch = tokenizer(
        batch, max_length=8192, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Inference only: no autograd graph is recorded, so no .detach() is needed.
    with torch.no_grad():
        outputs = model(**tokenized_batch)
        # Use the hidden state at sequence position 0 as the document embedding,
        # and move it off the accelerator so memory does not accumulate there.
        embeddings = outputs.last_hidden_state[:, 0].cpu()
    embedding_batches.append(embeddings)

dataset_embeddings = torch.cat(embedding_batches)
# Later cells attach these rows to `df` by position, so the counts must match.
assert dataset_embeddings.size(0) == len(docs), "expected one embedding per document"
# L2-normalize each row so that dot products equal cosine similarities.
dataset_embeddings = F.normalize(dataset_embeddings, p=2, dim=1)
dataset_embeddings.size()

100%|██████████| 495/495 [01:15<00:00,  6.60it/s]


torch.Size([31650, 768])

In [8]:
# Convert the embedding tensor to a NumPy matrix (one row per card). Rows are
# matched to `df` purely by position, so this must happen before any re-ordering.
embedding_matrix = dataset_embeddings.cpu().numpy()

# Attach the embeddings as a new column, then order the table by card name.
df_2 = df.with_columns(embedding=embedding_matrix).sort("name")

df_2

name,scryfallId,manaCost,type,text,power,toughness,loyalty,rarities,sets,embedding
str,str,str,str,str,str,str,str,list[enum],list[enum],"array[f32, 768]"
"""""Ach! Hans, Run!""""","""84f2c8f5-8e11-4639-b7de-00e4a2…","""{2}{R}{R}{G}{G}""","""Enchantment""","""At the beginning of your upkee…",,,,"[""rare""]","[""UNH""]","[0.021458, -0.036102, … -0.00607]"
"""""Brims"" Barone, Midway Mobster""","""68832214-2943-4253-8884-ffa490…","""{3}{W}{B}""","""Legendary Creature — Human Rog…","""When ~ enters, put a +1/+1 cou…","""5""","""4""",,"[""uncommon""]","[""UNF""]","[0.000261, 0.025207, … 0.005641]"
"""""Intimidation Tactics""""","""9b4e6022-44d2-4dfe-8f7a-51581e…","""{B}""","""Sorcery""","""Target opponent reveals their …",,,,"[""uncommon""]","[""DFT""]","[-0.031034, -0.004097, … 0.007289]"
"""""Lifetime"" Pass Holder""","""42293306-aaea-4542-8df4-813823…","""{B}""","""Creature — Zombie Guest""","""~ enters tapped.\nWhen ~ dies,…","""2""","""1""",,"[""rare""]","[""UNF""]","[0.001258, -0.004219, … 0.013614]"
"""""Name Sticker"" Goblin""","""fd1442b4-da59-4042-835f-143c8d…","""{2}{R}""","""Creature — Goblin Guest""","""When this creature enters from…","""2""","""2""",,"[""common""]","[""UNF""]","[-0.03243, -0.006905, … -0.026888]"
…,…,…,…,…,…,…,…,…,…,…
"""Éomer, King of Rohan""","""f2c11695-f22b-44d5-937c-2578f2…","""{3}{R}{W}""","""Legendary Creature — Human Nob…","""Double strike\n~ enters with a…","""2""","""2""",,"[""rare""]","[""LTC""]","[0.014415, 0.012861, … 0.019452]"
"""Éomer, Marshal of Rohan""","""0bd31ce9-9551-4efe-8bd2-b97d8e…","""{2}{R}{R}""","""Legendary Creature — Human Kni…","""Haste\nWhenever one or more ot…","""4""","""4""",,"[""rare""]","[""PLTR"", ""LTR""]","[-0.022492, 0.017429, … 0.046833]"
"""Éowyn, Fearless Knight""","""c1b37891-5ed9-47e4-8d2f-c2bfd8…","""{2}{R}{W}""","""Legendary Creature — Human Kni…","""Haste\nWhen ~ enters, exile ta…","""3""","""4""",,"[""rare""]","[""PLTR"", ""LTR""]","[-0.033708, 0.007089, … 0.02833]"
"""Éowyn, Lady of Rohan""","""e59710c4-24de-419e-a8a0-e8392d…","""{2}{W}""","""Legendary Creature — Human Nob…","""At the beginning of combat on …","""2""","""4""",,"[""uncommon""]","[""LTR""]","[-0.001433, 0.00514, … 0.048259]"


In [9]:
# Persist the name-sorted table, embedding column included, as a local Parquet file.
df_2.write_parquet(file="mtg_embeddings.parquet")

In [10]:
# Copy the local Parquet file to the Cloud Storage bucket with the gsutil CLI.
# NOTE: running a shell command forks the kernel process; because the tokenizer
# was already used above, huggingface/tokenizers prints a parallelism warning here.
!gsutil cp mtg_embeddings.parquet gs://maxw-imdb-embeddings/

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Copying file://mtg_embeddings.parquet [Content-Type=application/octet-stream]...
- [1 files][ 88.1 MiB/ 88.1 MiB]                                                
Operation completed over 1 objects/88.1 MiB.                                     
