In [None]:
!pip install sentence-transformers -qq

This is a sample demo of how to use [Flax Code Embeddings](https://huggingface.co/flax-sentence-embeddings/st-codesearch-distilroberta-base) with the [Google AI4Code](https://www.kaggle.com/competitions/AI4Code/) dataset. Like any other SBERT model, it can also be fine-tuned to suit this competition's task. The rest of the notebook is a sample usage demo that evaluates the model out of the box.

At first impression, it seems to do above average on English markdown cells, but not so well on a multilingual corpus. This is expected, since it is fine-tuned from RoBERTa, which is trained specifically on English corpora.

In [None]:
from sentence_transformers import SentenceTransformer, util
import pandas as pd, numpy as np, json
from colorama import Fore, Back, Style

# Directory holding the per-notebook JSON files (opened later as "<TRAIN_BASE_DIR>/<id>.json").
TRAIN_BASE_DIR = "../input/AI4Code/train"

# Table loaded from train_orders.csv; later cells read its "id" and "cell_order" columns.
df = pd.read_csv("../input/AI4Code/train_orders.csv")

# Pretrained sentence-embedding model, loaded by its hub identifier and used as-is below.
model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")

In [None]:
# Pick one random notebook; re-run this cell to inspect a different one.
# (Pass random_state=<int> to df.sample() if a reproducible pick is needed.)
row = df.sample().iloc[0]
cid = row["id"]
cells = row["cell_order"].split()  # whitespace-separated cell ids from the "cell_order" column

# Read the notebook JSON with an explicit encoding: the markdown can be
# multilingual, and open()'s default encoding depends on the platform locale.
with open(f"{TRAIN_BASE_DIR}/{cid}.json", encoding="utf-8") as f:
    dat = json.load(f)

# Split the cells, in order, into code cells (the search corpus) and all other
# cells (the queries). A single pass keeps each *_ids list aligned with its
# list of sources by construction.
codes, code_ids = [], []
queries, query_ids = [], []
for cell in cells:
    if dat["cell_type"][cell] == "code":
        codes.append(dat["source"][cell])
        code_ids.append(cell)
    else:
        queries.append(dat["source"][cell])
        query_ids.append(cell)

# Ordered cells display

In [None]:
# Print every cell following the order of `cells`: code cells in green,
# every other cell type in blue, each followed by a "$" separator line.
for cell in cells:
    content = dat["source"][cell]
    if dat["cell_type"][cell] == "code":
        colour = Fore.GREEN
    else:
        colour = Fore.BLUE
    print(f"Cell id: {cell}")
    # Reset the terminal style after the text so the colour does not bleed on.
    print("".join((colour, content, Style.RESET_ALL)))
    print("$" * 50)

# Predict closest code cell for a list of markdown cells

In [None]:
# Map each cell id to its position in `cells` once, instead of calling
# cells.index() twice per query. setdefault keeps the first occurrence,
# matching what list.index() returns.
cell_pos = {}
for pos, cell_id in enumerate(cells):
    cell_pos.setdefault(cell_id, pos)

op_dat = []
# With no code cells or no query cells there is nothing to match, and
# hits[0] below would raise IndexError; report an empty result instead.
if codes and queries:
    code_emb = model.encode(codes, convert_to_tensor=True)
    query_emb = model.encode(queries, convert_to_tensor=True)
    # One list of hits per query; each hit carries "corpus_id" (an index into
    # `codes`, and therefore into `code_ids`) and "score".
    query_hits = util.semantic_search(query_emb, code_emb)

    for qid, hits in zip(query_ids, query_hits):
        best = hits[0]  # first hit is taken as the closest code cell
        match_id = code_ids[best["corpus_id"]]
        op_dat.append({
            "score": best["score"],
            "code_match": match_id,
            "markdown": qid,
            # Number of cell positions between the query cell and the matched code cell.
            "abs_dist": abs(cell_pos[match_id] - cell_pos[qid]),
        })

print(json.dumps(op_dat, indent=2))