In [2]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import os
from scipy.spatial.distance import cosine
from sklearn.decomposition import PCA
from tqdm import tqdm
import pickle
import faiss
from transformers import AutoTokenizer, AutoModel

# Set up the models (CodeBERT embedding transform and MLP256 head)

In [3]:
class BERTEmbeddingTransform(object):
    """Callable that maps a source-code string to a single embedding vector.

    The text is tokenized, split into chunks of at most ``chunk_size`` token
    ids, and each chunk is wrapped in the start/end special tokens and passed
    through the encoder on its own. The second-to-last hidden layer of every
    chunk is concatenated along the token axis and averaged over all tokens
    (special tokens included).

    Args:
        bert_model: Encoder module. It is switched to eval mode, moved to
            ``device`` and placed in shared memory. When called with
            ``output_hidden_states=True`` it must return an object exposing
            ``hidden_states`` (falls back to positional index 1 otherwise).
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.
            Its ``cls_token_id`` / ``sep_token_id`` are used when present.
        device: Device on which tensors are created and the model is run.
        chunk_size: Maximum number of ordinary tokens per forward pass
            (two special tokens are added on top of this).

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """

    def __init__(self, bert_model, tokenizer, device='cpu', chunk_size=510):
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        bert_model.eval()
        bert_model = bert_model.to(device)
        bert_model.share_memory()
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device
        self.chunk_size = chunk_size
        # Prefer the ids the tokenizer itself declares; 0 and 2 are the values
        # this class previously hard-coded and are kept as the fallback.
        cls_id = getattr(tokenizer, "cls_token_id", None)
        sep_id = getattr(tokenizer, "sep_token_id", None)
        self.cls_token_id = 0 if cls_id is None else cls_id
        self.sep_token_id = 2 if sep_id is None else sep_id

    def __call__(self, sample):
        """Embed ``sample`` and return a 1-D tensor (mean over all token vectors)."""
        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(sample))
        # Explicit dtype: an empty id list would otherwise produce a float
        # tensor, which is not a valid index tensor for an embedding lookup.
        ids = torch.tensor(token_ids, dtype=torch.long, device=self.device)
        # Empty input still yields one (empty) chunk so the result is the
        # embedding of the bare special-token pair instead of an error.
        chunks = torch.split(ids, self.chunk_size) if ids.numel() > 0 else (ids,)

        start = torch.tensor([self.cls_token_id], dtype=torch.long, device=self.device)
        end = torch.tensor([self.sep_token_id], dtype=torch.long, device=self.device)

        with torch.no_grad():
            chunk_states = []
            for chunk in chunks:
                batch = torch.cat((start, chunk, end)).unsqueeze(0)  # (1, tokens)
                outputs = self.bert_model(batch, output_hidden_states=True)
                hidden_states = getattr(outputs, "hidden_states", None)
                if hidden_states is None:
                    # Legacy positional access, as used before.
                    hidden_states = outputs[1]
                chunk_states.append(hidden_states[-2])
            # Drop only the batch axis, then average over every token.
            return torch.cat(chunk_states, dim=1).squeeze(0).mean(dim=0)

In [4]:
class MLP256(torch.nn.Module):
    """Feed-forward head that projects 768-dimensional inputs to 256 dimensions.

    Three ``Linear -> GELU -> Dropout(p=0.1)`` stages (768->512, 512->512,
    512->512) are followed by a final ``Linear(512, 256)``. The layers live in
    ``self.mlp`` (a ``Sequential``), so parameter names are ``mlp.<index>.*``.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        widths = (768, 512, 512, 512)
        stages = []
        # Consecutive width pairs give the in/out sizes of each hidden stage.
        for n_in, n_out in zip(widths, widths[1:]):
            stages.extend((nn.Linear(n_in, n_out), nn.GELU(), nn.Dropout(p=0.1)))
        stages.append(nn.Linear(widths[-1], 256))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stack to ``x`` (last dimension 768) and return the result."""
        return self.mlp(x)

## Set up BERT

In [5]:
# Load the pretrained tokenizer and encoder (without the pooling layer), then
# wrap them in the embedding transform on the GPU when one is available.
tokenizer = AutoTokenizer.from_pretrained("neulab/codebert-cpp")
BERT = AutoModel.from_pretrained("neulab/codebert-cpp", add_pooling_layer=False)
BERT.eval()
bert_transform = BERTEmbeddingTransform(
    BERT, tokenizer, 'cuda' if torch.cuda.is_available() else 'cpu'
)

Downloading (…)okenizer_config.json: 0.00B [00:00, ?B/s]

Downloading (…)olve/main/vocab.json: 0.00B [00:00, ?B/s]

Downloading (…)olve/main/merges.txt: 0.00B [00:00, ?B/s]

Downloading (…)/main/tokenizer.json: 0.00B [00:00, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of the model checkpoint at neulab/codebert-cpp were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Set up MLP

In [6]:
# Build the projection head and place it on the GPU when one is available.
mlp = MLP256()
if torch.cuda.is_available():
    mlp.to("cuda")
# The checkpoint is deserialised onto the CPU and copied into the module's
# parameters. NOTE(review): torch.load unpickles the file, so only load
# checkpoints from a trusted source.
mlp.load_state_dict(
    torch.load("../models/MLP256_last.pth", map_location=torch.device('cpu'))
)
# Eval mode disables dropout; as the last expression it also shows the layout.
mlp.eval()

MLP256(
  (mlp): Sequential(
    (0): Linear(in_features=768, out_features=512, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): GELU(approximate='none')
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): GELU(approximate='none')
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=512, out_features=256, bias=True)
  )
)

# Load Data

In [15]:
!export KAGGLE_USERNAME=levbara
!export KAGGLE_KEY=01a0cd591f0703bd83e2beeb110172cd
!kaggle datasets download robertkhazhiev/codeforces-problems -f task_index_median_MLP256_03-04-23.bin -p ../data/interim

401 - Unauthorized


In [23]:
# Tabular inputs: the key table (first CSV column is used as the index) and
# the problem metadata table.
keys_df = pd.read_csv("../data/interim/keys_df.csv", index_col=0)
pwt_df = pd.read_csv("../data/interim/codeforces-problems.csv")
# Prebuilt FAISS index read from disk.
faiss_index = faiss.read_index('../data/interim/task_index_median_MLP256_03-04-23.bin')

In [24]:
# Preview the first rows of the key table (single column: problem_url).
keys_df.head()

Unnamed: 0,problem_url
0,/contest/1/problem/A
1,/contest/1/problem/B
2,/contest/1/problem/C
3,/contest/2/problem/A
4,/contest/2/problem/B


In [25]:
# Rewrite "contests" to "contest" in the URLs so they use the same
# /contest/<id>/problem/<letter> form as keys_df. The vectorised, literal
# (non-regex) replace leaves missing URLs as NaN instead of raising.
pwt_df["problem_url"] = pwt_df["problem_url"].str.replace("contests", "contest", regex=False)
# Tags are forced to text; note that missing tags become the literal "nan".
pwt_df["problem_tags"] = pwt_df["problem_tags"].astype(str)
pwt_df.head()

Unnamed: 0.1,Unnamed: 0,contest,problem_name,problem_statement,problem_tags,rating,problem_url
0,0,325,A,You are given n rectangles. The corners of rec...,implementation,1500.0,/contest/325/problem/A
1,1,325,B,Daniel is organizing a football tournament. He...,"binarysearch,math",1800.0,/contest/325/problem/B
2,2,325,C,Piegirl has found a monster and a book about m...,"dfsandsimilar,graphs,shortestpaths",2600.0,/contest/325/problem/C
3,3,325,D,"In a far away land, there exists a planet shap...",dsu,2900.0,/contest/325/problem/D
4,4,325,E,Piegirl found the red button. You have one las...,"combinatorics,dfsandsimilar,dsu,graphs,greedy",2800.0,/contest/325/problem/E


In [26]:
# Later cells use FAISS result ids as row labels of merged_df, so the merge
# must yield exactly one row per keys_df row, in keys_df order. A duplicated
# problem_url on the right side would add rows and silently shift every
# following position, so the right side is de-duplicated and the many-to-one
# relationship is validated.
merged_df = pd.merge(
    keys_df,
    pwt_df.drop_duplicates(subset="problem_url"),
    how="left",
    on="problem_url",
    validate="many_to_one",
)
assert len(merged_df) == len(keys_df), "merge changed the number of key rows"
merged_df.head()

Unnamed: 0.1,problem_url,Unnamed: 0,contest,problem_name,problem_statement,problem_tags,rating
0,/contest/1/problem/A,5.0,1.0,A,Theatre Square in the capital city of Berland ...,math,1000.0
1,/contest/1/problem/B,6.0,1.0,B,In the popular spreadsheets systems (for examp...,"implementation,math",1600.0
2,/contest/1/problem/C,7.0,1.0,C,Nowadays all circuses in Berland have a round ...,"geometry,math",2100.0
3,/contest/2/problem/A,6075.0,2.0,A,The winner of the card game popular in Berland...,"hashing,implementation",1500.0
4,/contest/2/problem/B,6076.0,2.0,B,"There is a square matrix n × n, consisting of ...","dp,math",2000.0


# Test FAISS

In [55]:
source_code = '#include <bits/stdc++.h>\n \nusing namespace std;\ntypedef long long ll;\n \nvector<vector<int>> adj,btree;\nint timer=1,bid=1;\nvector<int> vis,tin,low;\nset<pair<int,int>> edges;\n \nvoid bridge(int v,int p){\n    vis[v]=1;\n    tin[v] = low[v] = timer++;\n    for(auto u : adj[v]){\n        if(u == p) continue;\n        if(vis[u]){\n            low[v] = min(low[v],tin[u]);\n        }else{\n            bridge(u,v);\n            low[v] = min(low[v],low[u]);\n            if(low[u] > tin[v]){\n                // bridge   \n                edges.insert({min(u,v),max(u,v)});\n            }\n        }\n    }\n}\n \n \nvoid build(int v,int cur){\n    vis[v]=1;\n    for(auto u : adj[v]){\n        if(vis[u])continue;\n        if(edges.find({min(u,v),max(u,v)}) == edges.end()){\n            build(u,cur);\n        }else{\n            bid++;\n            btree[cur].push_back(bid);\n            btree[bid].push_back(cur);\n            build(u,bid);\n        }\n    }\n}\n \nint ans =0,s;\nvoid dfs(int v,int p,int cur){\n    cur++;\n    if(cur > ans){\n        ans = cur,s =v;\n    }\n    for(auto u : btree[v]){\n        if(u == p)continue;\n        dfs(u,v,cur);\n    }\n}\n \nvoid solve(){\n    int n,m;cin>>n>>m;\n    adj.resize(n+1),btree.resize(n+1);\n    for(int i=0;i<m;i++){\n        int x,y;cin>>x>>y;\n        adj[x].push_back(y);\n        adj[y].push_back(x);\n    }\n    vis.resize(n+1,0),tin.resize(n+1,0),low.resize(n+1,1e9);\n    bridge(1,-1);\n    vis.clear(),vis.resize(n+1,0);\n    build(1,1);\n    dfs(1,-1,0);\n    dfs(s, -1,0);\n    cout<<ans-1;\n}\n \n \nint main(){\n   ios_base::sync_with_stdio(0),cin.tie(0);\n   int t=1;//cin>>t;\n   while(t--) solve();\n}\n'

In [56]:
# Inference only: no_grad avoids building an autograd graph for the MLP pass.
with torch.no_grad():
    emb = mlp(bert_transform(source_code))
emb.size()

torch.Size([256])

In [57]:
# Move the embedding to host memory before converting: .numpy() raises on a
# CUDA tensor, which is what emb is whenever the models run on the GPU.
# The query is reshaped to (n_queries, 256); search returns (distances, ids).
result = faiss_index.search(emb.detach().cpu().numpy().reshape(-1, 256), k=20)
result

(array([[243513.67, 261347.53, 311992.  , 321059.  , 330046.12, 357358.03,
         360204.75, 360420.5 , 362626.97, 368758.06, 370713.16, 386262.06,
         389857.4 , 390099.53, 402195.38, 409448.38, 409600.75, 412134.3 ,
         436164.  , 438096.03]], dtype=float32),
 array([[3071, 2413, 5716, 2013, 1062, 1069, 1605, 1920, 5093, 5136, 5255,
         4606,   61, 3404, 2256, 6950, 2258,  819, 6710, 2020]]))

In [58]:
# result[1][0] holds the neighbour ids of the single query; each id is used as
# a row label of merged_df to print that problem's URL and tags.
for row_label in result[1][0]:
    print(
        merged_df.loc[row_label, "problem_url"],
        merged_df.loc[row_label, "problem_tags"],
    )

/contest/687/problem/E dfsandsimilar,graphs
/contest/542/problem/E graphs,shortestpaths
/contest/1220/problem/E dfsandsimilar,dp,dsu,graphs,greedy,trees
/contest/449/problem/B graphs,greedy,shortestpaths
/contest/229/problem/B binarysearch,datastructures,graphs,shortestpaths
/contest/230/problem/D binarysearch,graphs,shortestpaths
/contest/346/problem/D dp,graphs,shortestpaths
/contest/427/problem/C dfsandsimilar,graphs,twopointers
/contest/1101/problem/D datastructures,dfsandsimilar,dp,numbertheory,trees
/contest/1108/problem/F binarysearch,dsu,graphs,greedy
/contest/1139/problem/E flows,graphmatchings,graphs
/contest/1000/problem/E dfsandsimilar,graphs,trees
/contest/14/problem/D dfsandsimilar,dp,graphs,shortestpaths,trees,twopointers
/contest/757/problem/F datastructures,graphs,shortestpaths
/contest/505/problem/D dfsandsimilar
/contest/1467/problem/E nan
/contest/506/problem/B dfsandsimilar,graphs
/contest/178/problem/B1 nan
/contest/1423/problem/B nan
/contest/450/problem/D graphs

## Test FAISS with a second, simpler code sample

In [51]:
my_code = """#include <iostream>
int main() {
    int n, t, t_last = 0;
    std::cin >> n;
    for (int i = 0; i < n; ++i) {
        std::cin >> t;
        if (t - t_last > 15) {
            std::cout << t_last + 15 << '\\n';
            return 0;
        }
        t_last = t;
    }
    if (90 - t_last > 15) {
        std::cout << t_last + 15 << '\\n';
    } else {
        std::cout << 90 << '\\n';
    }
    return 0;
}
"""

In [52]:
# Inference only: no_grad avoids building an autograd graph for the MLP pass.
with torch.no_grad():
    my_emb = mlp(bert_transform(my_code))

my_emb.size()

torch.Size([256])

In [53]:
# Move the embedding to host memory before converting: .numpy() raises on a
# CUDA tensor, which is what my_emb is whenever the models run on the GPU.
# The query is reshaped to (n_queries, 256); search returns (distances, ids).
my_recs = faiss_index.search(my_emb.detach().cpu().numpy().reshape(-1, 256), k=40)
my_recs

(array([[ 522032.8 ,  555257.75,  606501.9 ,  607078.75,  628522.25,
          647745.8 ,  732096.1 ,  747057.4 ,  758509.  ,  767628.9 ,
          776300.5 ,  776618.4 ,  884603.6 ,  896551.  ,  898672.  ,
          904343.1 ,  911973.5 ,  942849.6 ,  978892.9 ,  990885.75,
          994977.1 ,  996726.06, 1018196.1 , 1024579.25, 1045882.2 ,
         1047523.44, 1049025.6 , 1066525.8 , 1069447.2 , 1092486.8 ,
         1096621.5 , 1096819.  , 1101231.5 , 1104593.8 , 1118921.5 ,
         1142213.5 , 1143113.6 , 1146463.  , 1160463.2 , 1169596.9 ]],
       dtype=float32),
 array([[ 880, 4022,  938, 1320,  666, 4161, 5290, 1919, 3989,  708, 1030,
         3980, 5307, 3007, 1790, 5245, 5390, 4596,  258, 4591, 5588, 1689,
         5185, 4472, 2038, 5383, 4891, 4727, 1584,  687, 2580, 3815, 1191,
         5171, 2017, 3852, 1126, 2665, 5099, 1348]]))

In [54]:
# my_recs[1][0] holds the neighbour ids of the single query; each id is used as
# a row label of merged_df to print that problem's URL and tags.
for row_label in my_recs[1][0]:
    print(
        merged_df.loc[row_label, "problem_url"],
        merged_df.loc[row_label, "problem_tags"],
    )

/contest/192/problem/B bruteforce,implementation
/contest/879/problem/B datastructures,implementation
/contest/205/problem/A bruteforce,implementation
/contest/285/problem/B implementation
/contest/144/problem/A implementation
/contest/911/problem/A implementation
/contest/1145/problem/A nan
/contest/427/problem/B datastructures,implementation
/contest/872/problem/B implementation
/contest/155/problem/A bruteforce
/contest/222/problem/A bruteforce,implementation
/contest/870/problem/B greedy
/contest/1147/problem/C games
/contest/673/problem/A implementation
/contest/386/problem/A implementation
/contest/1138/problem/A binarysearch,greedy,implementation
/contest/1162/problem/E games
/contest/999/problem/A bruteforce,implementation
/contest/54/problem/A implementation
/contest/998/problem/A constructivealgorithms,implementation
/contest/1199/problem/A implementation
/contest/365/problem/B implementation
/contest/1119/problem/A greedy,implementation
/contest/977/problem/C sortings
/conte