In [2]:
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    Collection,
    DataType
)
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
import pandas as pd
import torch
import glob

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Open the pymilvus connection under the alias "default"; the Collection and
# utility calls below do not pass an alias, so they go through this connection.
connections.connect(
    alias="default",
    host="127.0.0.1",
    port="19530",
)

In [7]:
# Name of the Milvus collection used by every cell in this notebook.
collection_name: str = "kursinis2"

In [9]:
# Run this cell only to rebuild the database from scratch: it drops the
# collection if one with this name already exists.
collection_exists = utility.has_collection(collection_name)
if collection_exists:
    utility.drop_collection(collection_name)

In [10]:
# Collection schema:
#   id            - INT64 primary key; auto_id is off, so ids are supplied on insert
#   embeddings    - 1024-dimensional float vector (the code embedding)
#   language      - short language label (max_length=16)
#   function_code - raw function source (max_length=65535)
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=1024, auto_id=False),
    FieldSchema(name="language", dtype=DataType.VARCHAR, max_length=16),
    FieldSchema(name="function_code", dtype=DataType.VARCHAR, max_length=65535),
]
schema = CollectionSchema(fields, description="Kursinio embeddings")

demo_coll = Collection(name=collection_name, schema=schema)
field_names = [field.name for field in schema.fields]
print(f"Kolekcija {collection_name} sukurta su šiais laukais {field_names}")

Kolekcija kursinis2 sukurta su šiais laukais ['id', 'embeddings', 'language', 'function_code']


In [3]:
# Load every Parquet shard in the working directory into one DataFrame and
# prepend a 1-based integer "id" column (inserted into Milvus as the primary key).
# glob() yields files in a filesystem-dependent order; sorting makes the row
# order -- and therefore the assigned ids -- reproducible across machines.
# NOTE(review): if the collection was populated from an unsorted listing, ids may
# differ after a re-run; re-insert the data together with this cell.
files = sorted(glob.glob("*.parquet"))
if not files:
    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    raise FileNotFoundError("No *.parquet files found in the current working directory")
dfs = [pd.read_parquet(path) for path in files]
combined_df = pd.concat(dfs, ignore_index=True)
combined_df.insert(0, 'id', combined_df.index + 1)

In [17]:
tokenizer = AutoTokenizer.from_pretrained("Salesforce/SFR-Embedding-Code-400M_R", trust_remote_code=True)
model = AutoModel.from_pretrained("Salesforce/SFR-Embedding-Code-400M_R", trust_remote_code=True)


def embed_code(code: str):
    """Embed one code snippet with the SFR code-embedding model.

    The snippet is tokenized (truncated to at most 8192 tokens) and run through
    the model without gradient tracking. The last hidden state of the first
    token of the single sequence is used as the embedding.

    NOTE(review): the vector is not L2-normalised here; presumably this matches
    how the stored "embedding" column was produced -- confirm before changing.

    Args:
        code: Source code text to embed.

    Returns:
        The embedding as a plain Python list of floats (suitable for the
        ``data=[...]`` argument of ``Collection.search``).
    """
    encoded = tokenizer(code, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        outputs = model(**encoded)
    first_token_state = outputs.last_hidden_state[0, 0]
    return first_token_state.cpu().numpy().tolist()

In [12]:
batch_size = 500

# Materialise each column once as a plain list. Milvus column-based insert
# takes one list per schema field, in schema order (id, embeddings, language,
# function_code).
all_ids, all_embeddings, all_languages, all_codes = (
    combined_df[column].tolist()
    for column in ("id", "embedding", "language", "function_code")
)

for start in tqdm(range(0, len(all_ids), batch_size), desc="Inserting batches"):
    stop = start + batch_size
    demo_coll.insert([
        column_values[start:stop]
        for column_values in (all_ids, all_embeddings, all_languages, all_codes)
    ])

Inserting batches: 100%|██████████| 200/200 [00:47<00:00,  4.20it/s]


In [13]:
# Flush the inserted rows before building an index on the collection.
demo_coll.flush()

In [14]:
# Build the default ANN index on the vector field, then load the collection so
# it can be searched.
# IVF_FLAT with L2 distance. nlist=1265 is about 4*sqrt(N) for N ~ 100k rows
# (200 batches of 500 were inserted), i.e. the common 4*sqrt(N) rule of thumb --
# presumably how it was chosen; confirm.
# Alternative tried earlier: HNSW, metric COSINE, M=16, efConstruction=200.
ivf_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 1265},
}

demo_coll.create_index("embeddings", ivf_params)
demo_coll.load()

In [15]:
# Display the indexes currently defined on the collection.
utility.list_indexes(collection_name)

['embeddings']

In [16]:
# To switch to a different index, the existing one has to be dropped first.
# The collection is released before the drop (as everywhere in this notebook).
# The drop is skipped when the vector field has no index, so re-running this
# cell no longer raises.
demo_coll.release()
if utility.list_indexes(collection_name, field_name="embeddings"):
    demo_coll.drop_index()

In [43]:
# Sanity check: embed a small Python function and fetch its 5 nearest stored
# functions (L2 distance, nprobe=122). Pass expr='language == "C"' to search()
# to restrict the hits to a single language.
sample_query = """def answer(a, b):
    if a < b:
        return 1
    elif a > b:
        return 2
    else:
        return 3"""

results = demo_coll.search(
    data=[embed_code(sample_query)],
    anns_field="embeddings",
    param={"metric_type": "L2", "params": {"nprobe": 122}},
    limit=5,
    output_fields=["function_code", "language"],
)
print(results)

data: [[{'id': 64, 'distance': 51.35755920410156, 'entity': {'function_code': 'int answer(int a, int b){\n\t\n\tif(a<b){\n\treturn 1;\n\t}else if(a>b){\n\treturn 2;\n\t}else{\n\treturn 3;\n\t}\n}', 'language': 'C'}}, {'id': 94896, 'distance': 157.60035705566406, 'entity': {'function_code': 'def solve(a,b):\n    if a <= b:\n        return(a)\n    else:\n        return(a-1)', 'language': 'Python'}}, {'id': 75060, 'distance': 168.23654174804688, 'entity': {'function_code': '){\n    if(a >= b) {\n    \t\tif(b >= c) {\n    \t\t\treturn b;\n    \t\t}\n    \t\telse if(a <= c) {\n    \t\t\treturn a;\n    \t\t}\n    \t\telse {\n    \t\t\treturn c;\n    \t\t}\n    \t}\n    \telse if(a > c) {\n    \t\treturn a;\n    \t}\n    \telse if(b > c) {\n    \t\treturn c;\n    \t}\n    \telse {\n    \t\treturn b;\n    \t}\n  }\n    \n}', 'language': 'Java'}}, {'id': 87825, 'distance': 170.51821899414062, 'entity': {'function_code': 'def compare(a, b):\n    if a < b:\n        return "a < b"\n    elif a == b

# TESTAVIMAS

In [167]:
def _ip_index(index_type, **build_params):
    """Return create_index() parameters for an inner-product index of the given type."""
    return {"index_type": index_type, "params": build_params, "metric_type": "IP"}


# Index variants to benchmark, as (label, create_index params, extra search
# params, device). Every variant uses inner product (IP) so it is comparable
# with the FLAT ground truth; the device slot is None for all of them.
index_configs = [
    ("FLAT",     _ip_index("FLAT"),                              {},              None),
    ("IVF_FLAT", _ip_index("IVF_FLAT", nlist=1265),              {"nprobe": 122}, None),
    ("IVF_SQ8",  _ip_index("IVF_SQ8", nlist=1265),               {"nprobe": 122}, None),
    ("IVF_PQ",   _ip_index("IVF_PQ", nlist=1265, m=64, nbits=8), {"nprobe": 122}, None),
    ("HNSW",     _ip_index("HNSW", M=16, efConstruction=128),    {"ef": 128},     None),
]

In [98]:
import random
import time

from pymilvus import MilvusException

# Benchmark configuration.
Q = 100  # number of queries used for testing
K = 10   # cut-off K for the @K metrics

# Reproducibly pick Q stored vectors to serve as the benchmark queries.
random.seed(42)
indices = random.sample(range(len(all_embeddings)), Q)
query_embeddings = [all_embeddings[idx] for idx in indices]

In [149]:
import math

# Ideal DCG@K for binary relevance: all K relevant items occupy ranks 1..K,
# contributing 1 / log2(rank + 1) each.
ideal_gains = [1.0 / math.log2(position + 1) for position in range(1, K + 1)]
idcg = sum(ideal_gains)

In [168]:
# Best-effort cleanup before benchmarking: release the collection and drop any
# existing index. A MilvusException here presumably means there is no index
# yet, so the cell just skips. The server's message is printed so that other
# failures are not silently mistaken for "no index".
try:
    demo_coll.release()
    demo_coll.drop_index()
except MilvusException as e:
    print(f"no existing index to drop, skipping ({e})")

In [169]:
import time

# Ground truth: exact (brute-force) FLAT search with inner product. The top-K
# ids returned for each query are treated as the true neighbours when the
# approximate indexes are scored in the benchmark.
flat_idx = {
    "index_type": "FLAT",
    "metric_type": "IP",
    "params": {},
}
# Count the queries directly instead of trusting the global Q, which another
# cell in this notebook reassigns.
n_queries = len(query_embeddings)

print("Building ground-truth FLAT index…")
t0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
demo_coll.create_index("embeddings", flat_idx)
build_flat = time.perf_counter() - t0

demo_coll.load()
t0 = time.perf_counter()
flat_res = demo_coll.search(
    data=query_embeddings,
    anns_field="embeddings",
    param={"metric_type": "IP"},
    limit=K,
)
# One batched search call; report amortised latency in MILLISECONDS per query.
lat_flat = (time.perf_counter() - t0) / n_queries * 1000
ground_truth = [[hit.id for hit in hits] for hits in flat_res]
print(f"FLAT build {build_flat:.2f}s, {lat_flat:.2f}ms/q")

demo_coll.release()
demo_coll.drop_index()

Building ground-truth FLAT index…
FLAT build 0.52s, 17.14ms/q


In [171]:
import math
import time


def _match_count(found_ids, true_ids):
    """Number of retrieved ids that also appear in the ground-truth list."""
    return len(set(found_ids) & set(true_ids))


def _reciprocal_rank(found_ids, true_ids):
    """1/rank of the first retrieved id present in the ground truth, else 0.0."""
    truth = set(true_ids)
    for rank, vid in enumerate(found_ids, start=1):
        if vid in truth:
            return 1.0 / rank
    return 0.0


def _ndcg(found_ids, true_ids, k):
    """Binary-relevance NDCG@k of found_ids against the ground-truth ids.

    The ideal DCG is computed per query from min(k, len(true_ids)) relevant
    items, so queries with fewer than k ground-truth ids are not under-scored.
    """
    truth = set(true_ids)
    dcg = 0.0
    for rank, vid in enumerate(found_ids, start=1):
        if vid in truth:
            dcg += 1.0 / math.log2(rank + 1)
    n_ideal = min(k, len(truth))
    if n_ideal == 0:
        return 0.0
    ideal = sum(1.0 / math.log2(rank + 1) for rank in range(1, n_ideal + 1))
    return dcg / ideal


# Count the queries directly instead of trusting the global Q, which another
# cell in this notebook reassigns.
n_queries = len(query_embeddings)

results = []
for name, idx_def, search_p, device in index_configs:
    # `device` is part of each config tuple but is not used by this benchmark.
    print(f"\n>>> Testing {name}")

    # Drop the previous variant's index before building the next one.
    if utility.list_indexes(collection_name, field_name="embeddings"):
        demo_coll.release()
        demo_coll.drop_index()

    t0 = time.perf_counter()
    demo_coll.create_index("embeddings", idx_def)
    build_t = time.perf_counter() - t0
    demo_coll.load()

    t1 = time.perf_counter()
    res = demo_coll.search(
        data=query_embeddings,
        anns_field="embeddings",
        # Index-specific knobs (nprobe / ef) are nested under "params", the
        # documented form and the one used by the other searches in this notebook.
        param={"metric_type": "IP", "params": dict(search_p)},
        limit=K,
    )
    lat = (time.perf_counter() - t1) / n_queries * 1000  # amortised ms per query

    found = [[hit.id for hit in hits] for hits in res]
    pairs = list(zip(found, ground_truth))

    # Recall@K: fraction of all ground-truth ids that were retrieved.
    total_truth = sum(len(true_ids) for _, true_ids in pairs)
    matches = sum(_match_count(found_ids, true_ids) for found_ids, true_ids in pairs)
    recall = matches / total_truth if total_truth else 0.0

    # MRR@K and NDCG@K, averaged over the queries.
    mrr = sum(_reciprocal_rank(found_ids, true_ids) for found_ids, true_ids in pairs) / n_queries
    ndcg = sum(_ndcg(found_ids, true_ids, K) for found_ids, true_ids in pairs) / n_queries

    results.append({
        "Index":          name,
        "Build (s)":      round(build_t, 2),
        "Latency (ms/q)": round(lat, 2),
        "Recall@K":       round(recall, 4),
        "MRR@K":          round(mrr, 4),
        "NDCG@K":         round(ndcg, 4),
    })

    demo_coll.release()

df = pd.DataFrame(results).sort_values("Recall@K", ascending=False)
print("\n=== Benchmark Results (IP) ===")
print(df)


>>> Testing FLAT

>>> Testing IVF_FLAT

>>> Testing IVF_SQ8

>>> Testing IVF_PQ

>>> Testing HNSW

=== Benchmark Results (IP) ===
      Index  Build (s)  Latency (ms/q)  Recall@K   MRR@K  NDCG@K
0      FLAT       2.08           16.22     1.000  1.0000  1.0000
1  IVF_FLAT      46.00            2.64     0.993  1.0000  0.9955
2   IVF_SQ8      48.55            2.95     0.979  1.0000  0.9857
4      HNSW      57.26            2.64     0.974  0.9900  0.9786
3    IVF_PQ      75.30            0.59     0.506  0.8983  0.5816


# IVF_FLAT TESTAVIMAS

In [25]:
# Remove whatever index the benchmark left on the vector field before building
# the IVF_FLAT / L2 index for the retrieval test. The drop is skipped when no
# index exists, so re-running this cell does not raise.
demo_coll.release()
if utility.list_indexes(collection_name, field_name="embeddings"):
    demo_coll.drop_index()

In [26]:
# IVF_FLAT index with L2 distance for the cross-language retrieval test
# (same settings as the default index built at the start of the notebook).
ivf_params = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=1265),
)

In [27]:
# Build the IVF_FLAT index on the vector field and load the collection so the
# retrieval test can search it.
demo_coll.create_index("embeddings", ivf_params)
demo_coll.load()
print("Indeksas sukurtas")

Indeksas sukurtas


In [29]:
# Cross-language test set. Each row pairs an original function
# (original_language / original_code) with a version of it converted to
# another language (converted_language / converted_code).
test_df = pd.read_excel("test_data.xlsx")
print(test_df.head(5))

  original_language                                      original_code  \
0                 C  int answer(int a, int b){\n\t\n\tif(a<b){\n\tr...   
1                 C  int answer(int a, int b){\n\t\n\tif(a<b){\n\tr...   
2                 C  int answer(int a, int b){\n\t\n\tif(a<b){\n\tr...   
3                 C  int answer(int a, int b){\n\t\n\tif(a<b){\n\tr...   
4                 C  void insertion_sort(int array[], int size) {\n...   

                                      converted_code converted_language  
0  def answer(a, b):\n    if a < b:\n        retu...             Python  
1  public static int answer(int a, int b) {\n    ...               Java  
2  int answer(int a, int b) {\n    if (a < b) {\n...                C++  
3  public static int Answer(int a, int b) {\n    ...                 C#  
4  def insertion_sort(array, size):\n    for i in...             Python  


In [30]:
# Reverse lookup from function source text to its row id. When the same source
# text occurs in several rows, the id of the last such row wins.
code_to_id = dict(zip(combined_df['function_code'], combined_df['id']))

In [35]:
# Cross-language retrieval test. Each converted snippet is embedded and the
# search is restricted to functions in the ORIGINAL language; we check whether
# the original function comes back. Reports Top-1 accuracy and MRR@TOP_K.
TOP_K = 5
search_params = {"metric_type": "L2", "params": {"nprobe": 122}}

# Number of test cases. Deliberately not named Q: Q is the benchmark's query
# count, and overwriting it would corrupt re-runs of the benchmark cells.
n_tests = len(test_df)
if n_tests == 0:
    raise ValueError("test_data.xlsx contains no test cases")

correct_top1 = 0
sum_rr = 0.0

for _, row in tqdm(test_df.iterrows(), total=n_tests):
    query_emb = embed_code(row['converted_code'])

    expr = f'language == "{row["original_language"]}"'
    res = demo_coll.search(
        data=[query_emb],
        anns_field="embeddings",
        param=search_params,
        limit=TOP_K,
        expr=expr,
        output_fields=["function_code"],
    )

    hits = res[0]
    if len(hits) == 0:
        # No stored function in this language: count as a miss (rr = 0)
        # instead of failing on hits[0].
        print(f"No hits for language {row['original_language']!r}")
        continue

    top1_code = hits[0].entity['function_code']
    print(top1_code)
    returned_codes = [hit.entity['function_code'].strip() for hit in hits]

    # Reciprocal rank of the original function among the returned hits.
    original = row['original_code'].strip()
    rr = 0.0
    for rank, code in enumerate(returned_codes, start=1):
        if code == original:
            rr = 1.0 / rank
            break
    sum_rr += rr

    # rr == 1.0 exactly when the top-1 hit is the original function.
    if rr == 1.0:
        print("ATITINKA")  # "match"
        correct_top1 += 1

top1_accuracy = correct_top1 / n_tests
mrr_at_5 = sum_rr / n_tests  # name kept for compatibility; this is MRR@TOP_K

print(f"Top-1 Accuracy: {top1_accuracy:.2%}")
print(f"MRR@{TOP_K}:         {mrr_at_5:.4f}")

  2%|▎         | 1/40 [00:00<00:31,  1.23it/s]

int answer(int a, int b){
	
	if(a<b){
	return 1;
	}else if(a>b){
	return 2;
	}else{
	return 3;
	}
}
ATITINKA


  5%|▌         | 2/40 [00:01<00:27,  1.38it/s]

int answer(int a, int b){
	
	if(a<b){
	return 1;
	}else if(a>b){
	return 2;
	}else{
	return 3;
	}
}
ATITINKA


  8%|▊         | 3/40 [00:02<00:26,  1.41it/s]

int answer(int a, int b){
	
	if(a<b){
	return 1;
	}else if(a>b){
	return 2;
	}else{
	return 3;
	}
}
ATITINKA


 10%|█         | 4/40 [00:02<00:26,  1.37it/s]

int answer(int a, int b){
	
	if(a<b){
	return 1;
	}else if(a>b){
	return 2;
	}else{
	return 3;
	}
}
ATITINKA


 12%|█▎        | 5/40 [00:03<00:27,  1.29it/s]

void insertion_sort(int array[], int size) {
	for(int i = 1; i < size; i++) {
		int tmp = array[i];
		int j = i - 1;
		while(j >= 0 && array[j] > tmp) {
			array[j+1] = array[j];
			j--;
		}
		array[j+1] = tmp;

		print(array, size);
	}
}
ATITINKA


 15%|█▌        | 6/40 [00:04<00:28,  1.18it/s]

void insertion_sort(int array[], int size) {
	for(int i = 1; i < size; i++) {
		int tmp = array[i];
		int j = i - 1;
		while(j >= 0 && array[j] > tmp) {
			array[j+1] = array[j];
			j--;
		}
		array[j+1] = tmp;

		print(array, size);
	}
}
ATITINKA


 18%|█▊        | 7/40 [00:05<00:29,  1.12it/s]

void insertion_sort(int array[], int size) {
	for(int i = 1; i < size; i++) {
		int tmp = array[i];
		int j = i - 1;
		while(j >= 0 && array[j] > tmp) {
			array[j+1] = array[j];
			j--;
		}
		array[j+1] = tmp;

		print(array, size);
	}
}
ATITINKA


 20%|██        | 8/40 [00:06<00:30,  1.05it/s]

void insertion_sort(int array[], int size) {
	for(int i = 1; i < size; i++) {
		int tmp = array[i];
		int j = i - 1;
		while(j >= 0 && array[j] > tmp) {
			array[j+1] = array[j];
			j--;
		}
		array[j+1] = tmp;

		print(array, size);
	}
}
ATITINKA


 22%|██▎       | 9/40 [00:07<00:27,  1.12it/s]

private static object GetResult()
        {
            var X = ReadLong();

            var year = 0L;
            var yen = 100L;
            while (yen < X)
            {
                yen = (long)Math.Floor(yen * 1.01D);
                year++;
            }

            return year;
        }
ATITINKA


 25%|██▌       | 10/40 [00:08<00:25,  1.16it/s]

private static object GetResult()
        {
            var X = ReadLong();

            var year = 0L;
            var yen = 100L;
            while (yen < X)
            {
                yen = (long)Math.Floor(yen * 1.01D);
                year++;
            }

            return year;
        }
ATITINKA


 28%|██▊       | 11/40 [00:09<00:24,  1.20it/s]

private static object GetResult()
        {
            var X = ReadLong();

            var year = 0L;
            var yen = 100L;
            while (yen < X)
            {
                yen = (long)Math.Floor(yen * 1.01D);
                year++;
            }

            return year;
        }
ATITINKA


 30%|███       | 12/40 [00:09<00:23,  1.21it/s]

private static object GetResult()
        {
            var X = ReadLong();

            var year = 0L;
            var yen = 100L;
            while (yen < X)
            {
                yen = (long)Math.Floor(yen * 1.01D);
                year++;
            }

            return year;
        }
ATITINKA


 32%|███▎      | 13/40 [00:10<00:21,  1.25it/s]

public static bool Palindrome(string s, int n)
        {
            int mid = (n + 1) / 2;
            for(int i=0; i<mid; i++)
            {
                if (s[i] != s[n - i - 1]) return false;
            }
            return true;

        }
ATITINKA


 35%|███▌      | 14/40 [00:11<00:21,  1.23it/s]

public static bool Palindrome(string s, int n)
        {
            int mid = (n + 1) / 2;
            for(int i=0; i<mid; i++)
            {
                if (s[i] != s[n - i - 1]) return false;
            }
            return true;

        }
ATITINKA


 38%|███▊      | 15/40 [00:12<00:20,  1.23it/s]

public static bool Palindrome(string s, int n)
        {
            int mid = (n + 1) / 2;
            for(int i=0; i<mid; i++)
            {
                if (s[i] != s[n - i - 1]) return false;
            }
            return true;

        }
ATITINKA


 40%|████      | 16/40 [00:13<00:19,  1.22it/s]

public static bool Palindrome(string s, int n)
        {
            int mid = (n + 1) / 2;
            for(int i=0; i<mid; i++)
            {
                if (s[i] != s[n - i - 1]) return false;
            }
            return true;

        }
ATITINKA


 42%|████▎     | 17/40 [00:13<00:17,  1.28it/s]

bool isPrime(int x)
{
  for (int i = 2; i*i <= x; ++i)
  {
    if(x%i==0) return false;
  }
  return true;
}


 45%|████▌     | 18/40 [00:14<00:16,  1.30it/s]

bool isprime(int x){
	for(int i = 2; i*i <= x; i++){
		if(x% i == 0){
			return false;
		}
	}
	return true;
}


 48%|████▊     | 19/40 [00:15<00:15,  1.35it/s]

bool isprime(int x){
	for(int i = 2; i*i <= x; i++){
		if(x% i == 0){
			return false;
		}
	}
	return true;
}


 50%|█████     | 20/40 [00:16<00:14,  1.37it/s]

bool isprime(int x){
    for (int i = 2; i*i <= x; i++){
        if (x%i == 0) return false;
    }
    return true;
}
ATITINKA


 52%|█████▎    | 21/40 [00:16<00:14,  1.33it/s]

int choco(int day, int n){
    int res = 1;
    int cnt = 1;
    while(day + res<= n){
        res += day;
        cnt++;
    }
    return cnt;
}
ATITINKA


 55%|█████▌    | 22/40 [00:17<00:13,  1.34it/s]

int choco(int day, int n){
    int res = 1;
    int cnt = 1;
    while(day + res<= n){
        res += day;
        cnt++;
    }
    return cnt;
}
ATITINKA


 57%|█████▊    | 23/40 [00:18<00:12,  1.35it/s]

int choco(int day, int n){
    int res = 1;
    int cnt = 1;
    while(day + res<= n){
        res += day;
        cnt++;
    }
    return cnt;
}
ATITINKA


 60%|██████    | 24/40 [00:19<00:11,  1.35it/s]

int choco(int day, int n){
    int res = 1;
    int cnt = 1;
    while(day + res<= n){
        res += day;
        cnt++;
    }
    return cnt;
}
ATITINKA


 62%|██████▎   | 25/40 [00:19<00:11,  1.27it/s]

def solve(N: int, R: int):
    if N >= 10:
        print(R)
    else:
        print(R+100*(10-N))
    return


 65%|██████▌   | 26/40 [00:20<00:11,  1.27it/s]

def solve(N: int, R: int):
    if N >= 10:
        print(R)
    else:
        print(R+100*(10-N))
    return


 68%|██████▊   | 27/40 [00:21<00:10,  1.29it/s]

def solve(N: int, R: int):
    if N >= 10:
        print(R)
    else:
        print(R+100*(10-N))
    return


 70%|███████   | 28/40 [00:22<00:09,  1.33it/s]

def solve(N: int, R: int):
    if N >= 10:
        print(R)
    else:
        print(R + 100 * (10 - N))
ATITINKA


 72%|███████▎  | 29/40 [00:22<00:07,  1.39it/s]

def move(q, up, down):
    if q > 0:
        return up * q
    else:
        return down * (-q)
ATITINKA


 75%|███████▌  | 30/40 [00:23<00:07,  1.41it/s]

def move(q, up, down):
    if q > 0:
        return up * q
    else:
        return down * (-q)
ATITINKA


 78%|███████▊  | 31/40 [00:24<00:06,  1.41it/s]

def move(q, up, down):
    if q > 0:
        return up * q
    else:
        return down * (-q)
ATITINKA


 80%|████████  | 32/40 [00:24<00:05,  1.44it/s]

def move(q, up, down):
    if q > 0:
        return up * q
    else:
        return down * (-q)
ATITINKA


 82%|████████▎ | 33/40 [00:25<00:05,  1.39it/s]

static long max(long[] ll){
        long max = ll[0];
        for(long l : ll){
            if(max < l){
                max = l;
            }
        }
        return max;
    }
ATITINKA


 85%|████████▌ | 34/40 [00:26<00:04,  1.40it/s]

static long max(long[] ll){
        long max = ll[0];
        for(long l : ll){
            if(max < l){
                max = l;
            }
        }
        return max;
    }
ATITINKA


 88%|████████▊ | 35/40 [00:27<00:03,  1.30it/s]

static long max(long[] ll){
        long max = ll[0];
        for(long l : ll){
            if(max < l){
                max = l;
            }
        }
        return max;
    }
ATITINKA


 90%|█████████ | 36/40 [00:28<00:03,  1.28it/s]

static long max(long[] ll){
        long max = ll[0];
        for(long l : ll){
            if(max < l){
                max = l;
            }
        }
        return max;
    }
ATITINKA


 92%|█████████▎| 37/40 [00:28<00:02,  1.24it/s]

public void print() {
        for (int i = 0; i < this.size; i++) {
            System.out.print(this.mx[i]);
            if (i == this.size - 1) {
                System.out.println();
            } else {
                System.out.print(" ");
            }
        }
    }
ATITINKA


 95%|█████████▌| 38/40 [00:29<00:01,  1.26it/s]

public void print() {
        for (int i = 0; i < this.size; i++) {
            System.out.print(this.mx[i]);
            if (i == this.size - 1) {
                System.out.println();
            } else {
                System.out.print(" ");
            }
        }
    }
ATITINKA


 98%|█████████▊| 39/40 [00:30<00:00,  1.25it/s]

public void print() {
        for (int i = 0; i < this.size; i++) {
            System.out.print(this.mx[i]);
            if (i == this.size - 1) {
                System.out.println();
            } else {
                System.out.print(" ");
            }
        }
    }
ATITINKA


100%|██████████| 40/40 [00:31<00:00,  1.28it/s]

public void print() {
        for (int i = 0; i < this.size; i++) {
            System.out.print(this.mx[i]);
            if (i == this.size - 1) {
                System.out.println();
            } else {
                System.out.print(" ");
            }
        }
    }
ATITINKA
Top-1 Accuracy: 85.00%
MRR@5:         0.9250



