In [1]:
# Environment setup cell: installs the Python packages and the OpenMP system
# library used by the rest of the notebook. Only pytorch-lightning is pinned
# (1.8.3); the other packages install whatever version is current.
# The next line is a disabled TPU/XLA install command kept for reference.
# !pip install cloud-tpu-client==0.10 torch==1.12.0 torchvision==0.13.0 transformers https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl

!pip install transformers
!pip install datasets
!apt install libomp-dev --yes
!pip install pytorch-lightning==1.8.3
!pip install neptune-client

# `device` selects which faiss build gets installed (GPU vs CPU).
# NOTE(review): in the visible cells `device` is only consulted here; the
# trainer and the metric code choose GPU usage on their own.
device = 'cuda'
if device == 'cuda':
    !pip install faiss-gpu -q
else:
    !pip install faiss-cpu -q




The following additional packages will be installed:
  libomp-10-dev libomp5-10
Suggested packages:
  libomp-10-doc
The following NEW packages will be installed:
  libomp-10-dev libomp-dev libomp5-10
0 upgraded, 3 newly installed, 0 to remove and 93 not upgraded.
Need to get 351 kB of archives.
After this operation, 2281 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu focal/universe amd64 libomp5-10 amd64 1:10.0.0-4ubuntu1 [300 kB]
Get:2 http://archive.ubuntu.com/ubuntu focal/universe amd64 libomp-10-dev amd64 1:10.0.0-4ubuntu1 [47.7 kB]
Get:3 http://archive.ubuntu.com/ubuntu focal/universe amd64 libomp-dev amd64 1:10.0-50~exp1 [2824 B]
Fetched 351 kB in 0s (978 kB/s)

7[0;23r8[1ASelecting previously unselected package libomp5-10:amd64.
(Reading database ... 108827 files and directories currently installed.)
Preparing to unpack .../libomp5-10_1%3a10.0.0-4ubuntu1_amd64.deb ...
7[24;0f[42m[30mProgress: [  0%][49m[39m [.....

In [2]:
import numpy as np
import torch
from torch.nn import functional as F
from torchvision import transforms
import pytorch_lightning as pl

from datasets import load_dataset, concatenate_datasets, load_from_disk
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
from enum import Enum, auto
from tqdm.notebook import tqdm

    
import faiss
import faiss.contrib.torch_utils

In [3]:
import os
# Must be set before any tokenizer work happens.
# NOTE(review): presumably this avoids tokenizer thread-pool trouble in the
# forked DataLoader workers / ddp_fork processes used later -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run configuration shared by the data loaders, callbacks and trainer below.
FLAGS = dict(
    batch_size_per_device=5,
    max_epochs=3,
    es_min_delta=0.1,   # minimum `valid_mrr` gain counted as an improvement
    num_devices=2,
    num_workers=2,
)
# Derived value: per-device batch size times the number of devices.
FLAGS.update(batch_size=FLAGS['batch_size_per_device'] * FLAGS['num_devices'])

In [4]:
# Tokenizer for the "microsoft/codebert-base" checkpoint -- the same checkpoint
# name the encoder in `Model` is loaded from.
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/498 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/150 [00:00<?, ?B/s]

In [5]:
class SeqType(Enum):
    """Selects which side of a (docstring, code) pair is being formatted."""

    # Source-code token sequence.
    CODE = auto()
    # Natural-language docstring token sequence.
    DOC = auto()

    
class TokenizeTransform(object):
    """Batched transform that tokenizes the code and docstring of each example.

    Each side is encoded on its own: the code input is built with an empty
    docstring part and the docstring input with an empty code part, both using
    the layout ``cls + doc + sep + code + sep``.

    NOTE(review): the CLS/SEP strings are concatenated by hand and the
    tokenizer is then called with its default special-token handling. If that
    default also inserts CLS/SEP, the encoded sequences carry them twice --
    confirm this is intended before changing it (cached datasets and trained
    checkpoints depend on the current encoding).

    Args:
        tokenizer: callable tokenizer exposing ``cls_token`` and ``sep_token``
            and returning an encoding with ``input_ids``, ``token_type_ids``
            and ``attention_mask`` attributes.
        max_length: length every sequence is padded/truncated to.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        return self.create_one_batch(batch)

    def create_one_batch(self, batch):
        """Tokenize both sides of ``batch``.

        Args:
            batch: mapping with ``'code_tokens'`` and ``'docstring_tokens'``,
                each a list (one entry per example) of token lists.

        Returns:
            dict with ``code_*`` and ``doc_*`` entries for ``tokens_ids``,
            ``token_type_ids`` and ``attention_mask``.
        """
        features = {}
        # Code first, then doc, so the key order of the result is stable.
        for prefix, seq_type in (('code', SeqType.CODE), ('doc', SeqType.DOC)):
            features.update(self._encode(batch, seq_type, prefix))
        return features

    def _encode(self, batch, seq_type, prefix):
        """Tokenize one side of the batch and prefix the resulting column names."""
        texts = self.get_formatted_input(batch, seq_type)
        encoding = self.tokenizer(texts, padding='max_length', return_token_type_ids=True,
                                  truncation=True, max_length=self.max_length)
        return {f'{prefix}_tokens_ids': encoding.input_ids,
                f'{prefix}_token_type_ids': encoding.token_type_ids,
                f'{prefix}_attention_mask': encoding.attention_mask}

    def get_formatted_input(self, batch, seq_type):
        """Return the formatted input strings for the requested ``SeqType``.

        Raises:
            ValueError: if ``seq_type`` is neither ``SeqType.CODE`` nor
                ``SeqType.DOC`` (still an ``Exception`` for existing callers).
        """
        if seq_type == SeqType.CODE:
            return self.get_formatted_input_with_f(batch, TokenizeTransform.get_batched_tokens4code)
        elif seq_type == SeqType.DOC:
            return self.get_formatted_input_with_f(batch, TokenizeTransform.get_batched_tokens4docstring)
        else:
            raise ValueError("Incorrect sequence type")

    def get_formatted_input_with_f(self, batch, get_tokens):
        """Build ``cls + doc + sep + code + sep`` strings.

        ``get_tokens(batch)`` must return ``(doc_texts, code_texts)``, two
        equally long lists of strings.
        """
        cls_token, sep_token = self.tokenizer.cls_token, self.tokenizer.sep_token
        return [cls_token + doc_tokens + sep_token + code_tokens + sep_token
                for doc_tokens, code_tokens in zip(*get_tokens(batch))]

    # The helpers below take no ``self``; marking them as static methods makes
    # them callable through an instance as well as through the class.
    @staticmethod
    def get_batched_tokens4code(batch):
        """Return ``(doc_texts, code_texts)`` with the doc side left empty."""
        code_tokens = TokenizeTransform.get_code_tokens(batch)
        doc_tokens = [''] * len(code_tokens)
        return doc_tokens, code_tokens

    @staticmethod
    def get_batched_tokens4docstring(batch):
        """Return ``(doc_texts, code_texts)`` with the code side left empty."""
        doc_tokens = TokenizeTransform.get_docstring_tokens(batch)
        code_tokens = [''] * len(doc_tokens)
        return doc_tokens, code_tokens

    @staticmethod
    def get_docstring_tokens(batch):
        """Join each example's docstring tokens with single spaces."""
        return [' '.join(tokens) for tokens in batch['docstring_tokens']]

    @staticmethod
    def get_code_tokens(batch):
        """Join each example's code tokens with single spaces."""
        return [' '.join(tokens) for tokens in batch['code_tokens']]
    
    
# Shared transform instance; the dataset cell passes it to `.map(..., batched=True)`.
tokenize_transform = TokenizeTransform(tokenizer)

In [6]:
# When True, reuse already-tokenized splits saved in the working directory
# instead of downloading and tokenizing the dataset again.
DATASETS_ON_DISK = False


def _tokenize_split(split, columns):
    """Tokenize one raw split with `tokenize_transform` and expose it as torch tensors.

    `columns` are the raw columns to drop, so only the tokenized features remain.
    """
    return split.map(tokenize_transform, batched=True, remove_columns=columns,
                     batch_size=FLAGS['batch_size']).with_format("torch")


def _make_loader(split_dataset):
    """Build a DataLoader with the settings shared by the train/valid/test splits.

    NOTE(review): `batch_size` is FLAGS['batch_size'] (per-device size times
    number of devices). With the DDP strategy used by the trainer, each process
    presumably draws this many samples per step -- confirm that is intended
    and not FLAGS['batch_size_per_device'].
    NOTE(review): `drop_last=True` is also applied to the validation and test
    loaders, so a trailing partial batch is excluded from the reported
    metrics; no `shuffle` argument is passed for the training loader either.
    """
    num_workers = FLAGS['num_workers']
    return torch.utils.data.DataLoader(
        split_dataset,
        batch_size=FLAGS['batch_size'],
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
        pin_memory=True,
        drop_last=True)


if DATASETS_ON_DISK:
    train_dataset = load_from_disk("train_dataset").with_format("torch")
    valid_dataset = load_from_disk("valid_dataset").with_format("torch")
    test_dataset = load_from_disk("test_dataset").with_format("torch")
else:
    dataset = load_dataset("code_x_glue_ct_code_to_text", 'python')
    train_data, valid_data, test_data = dataset['train'], dataset['validation'], dataset['test']
    columns = train_data.column_names

    train_dataset = _tokenize_split(train_data, columns)
    valid_dataset = _tokenize_split(valid_data, columns)
    test_dataset = _tokenize_split(test_data, columns)

train_loader = _make_loader(train_dataset)
valid_loader = _make_loader(valid_dataset)
test_loader = _make_loader(test_dataset)

Downloading builder script:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/961 [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/475 [00:00<?, ?B/s]

Downloading and preparing dataset code_x_glue_ct_code_to_text/python (download: 909.14 MiB, generated: 869.00 MiB, post-processed: Unknown size, total: 1.74 GiB) to /root/.cache/huggingface/datasets/code_x_glue_ct_code_to_text/python/0.0.0/f8b7e9d51f609a87e7ec7c7431706d4ee0b402e3398560410313d4acc67060a0...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/941M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.4M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/251820 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13914 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/14918 [00:00<?, ? examples/s]

Dataset code_x_glue_ct_code_to_text downloaded and prepared to /root/.cache/huggingface/datasets/code_x_glue_ct_code_to_text/python/0.0.0/f8b7e9d51f609a87e7ec7c7431706d4ee0b402e3398560410313d4acc67060a0. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/25182 [00:00<?, ?ba/s]

  0%|          | 0/1392 [00:00<?, ?ba/s]

  0%|          | 0/1492 [00:00<?, ?ba/s]

In [7]:
class FaissKNeighbors:
    """Exact nearest-neighbour search over a flat L2 faiss index.

    Args:
        is_cuda: if true, the index is moved to a GPU after creation.
        device: id of the GPU that hosts the index when ``is_cuda`` is true.
    """

    def __init__(self, is_cuda, device=0):
        self.index = None
        self.is_cuda = is_cuda
        self.device = device
        # Created on first GPU fit() and reused afterwards, so repeated fit()
        # calls (one per chunk in Model.calculate_metrics) do not set up a new
        # StandardGpuResources object every time.
        self._gpu_resources = None

    def fit(self, X):
        """(Re)build the index from ``X``; ``X.shape[1]`` is the vector dimension."""
        index = faiss.IndexFlatL2(X.shape[1])
        if self.is_cuda:
            if self._gpu_resources is None:
                self._gpu_resources = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(self._gpu_resources, self.device, index)
        index.add(X)
        # Publish the index only once it has been filled successfully.
        self.index = index

    def predict(self, X, k):
        """Return the indices of the ``k`` nearest indexed vectors for each row of ``X``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.index is None:
            raise RuntimeError("FaissKNeighbors.predict() called before fit()")
        _, indices = self.index.search(X, k=k)
        return indices

In [8]:
class Model(pl.LightningModule):
    """CodeBERT encoder fine-tuned to match docstrings with their code.

    Code and docstring are embedded by the same encoder (its ``pooler_output``).
    The training signal is a cross-entropy over the in-batch score matrix
    ``scores[i, j] = <doc_i, code_j>``, whose correct class for row ``i`` is
    column ``i``. At the end of validation/test epochs, retrieval metrics
    (MRR, exact match, top-3) are computed over the collected embeddings.

    NOTE(review): the loss ranks by dot product, while the metrics rank by L2
    distance (``FaissKNeighbors`` builds an ``IndexFlatL2``). For embeddings
    that are not length-normalised the two rankings can differ -- confirm
    this is intended.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss()

        self.model = AutoModel.from_pretrained("microsoft/codebert-base")
        # Fine-tune the whole encoder: every parameter is marked trainable.
        for p in self.model.parameters():
            p.requires_grad = True

    # ------------------------------------------------------------------ #
    # Shared building blocks for the step methods
    # ------------------------------------------------------------------ #
    def _embed(self, batch, prefix):
        """Encode one side of the batch (``prefix`` is 'code' or 'doc')."""
        return self.model(input_ids=batch[f'{prefix}_tokens_ids'],
                          token_type_ids=batch[f'{prefix}_token_type_ids'],
                          attention_mask=batch[f'{prefix}_attention_mask']).pooler_output

    def _embed_pair(self, batch):
        """Return ``(code_embs, doc_embs)``; code is encoded first, then doc."""
        code_embs = self._embed(batch, 'code')
        doc_embs = self._embed(batch, 'doc')
        return code_embs, doc_embs

    @staticmethod
    def _score(doc_embs, code_embs):
        """Pairwise dot products: ``scores[i, j] = <doc_embs[i], code_embs[j]>``."""
        return (doc_embs[:, None, :] * code_embs[None, :, :]).sum(-1)

    def _embed_and_loss(self, batch):
        """Return ``(code_embs, doc_embs, loss)`` for one batch."""
        code_embs, doc_embs = self._embed_pair(batch)
        scores = self._score(doc_embs, code_embs)
        # The matching code for docstring i sits at position i of the batch.
        targets = torch.arange(len(batch['code_tokens_ids']), device=scores.device)
        return code_embs, doc_embs, self.loss_fn(scores, targets)

    def _eval_step(self, batch, loss_name):
        """Log the batch loss under ``loss_name`` and hand the embeddings to the epoch-end hook."""
        code_embs, doc_embs, loss = self._embed_and_loss(batch)
        self.log(loss_name, loss, prog_bar=True)
        return {'code_embs': code_embs, 'doc_embs': doc_embs}

    def _log_retrieval_metrics(self, outputs, stage):
        """Compute retrieval metrics and log them as ``<stage>_mrr`` etc."""
        mrr, exact_match, top3 = self.calculate_metrics(outputs)
        self.log_dict({f'{stage}_mrr': mrr,
                       f'{stage}_exact_match': exact_match,
                       f'{stage}_top3': top3},
                      prog_bar=True,
                      sync_dist=True)

    # ------------------------------------------------------------------ #
    # Lightning hooks
    # ------------------------------------------------------------------ #
    def forward(self, batch):
        """Return the docstring-by-code score matrix for ``batch``."""
        code_embs, doc_embs = self._embed_pair(batch)
        return self._score(doc_embs, code_embs)

    def training_step(self, batch, batch_nb):
        _, _, loss = self._embed_and_loss(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_nb):
        return self._eval_step(batch, 'test_loss')

    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, 'valid_loss')

    def test_epoch_end(self, outputs):
        self._log_retrieval_metrics(outputs, 'test')

    def validation_epoch_end(self, outputs):
        self._log_retrieval_metrics(outputs, 'valid')

    # ------------------------------------------------------------------ #
    # Retrieval metrics
    # ------------------------------------------------------------------ #
    def calculate_metrics(self, outputs, chunk_size=1000):
        """Compute MRR, exact-match and top-3 hit rate over the step outputs.

        The embeddings are split into consecutive chunks of at most
        ``chunk_size`` pairs. Within a chunk every docstring is ranked against
        all code snippets of that chunk. The per-chunk values are then
        averaged with equal weight, including a shorter final chunk.

        Args:
            outputs: list of dicts with 'code_embs' and 'doc_embs', as
                returned by ``validation_step`` / ``test_step``.
            chunk_size: maximum number of pairs ranked against each other.

        Returns:
            ``(mrr, exact_match, top3)`` as Python floats.
        """
        code_embs = torch.cat([batch['code_embs'] for batch in outputs], dim=0).to(torch.float32)
        doc_embs = torch.cat([batch['doc_embs'] for batch in outputs], dim=0).to(torch.float32)

        # Use a GPU index only when the embeddings live on a GPU, so the
        # metrics also work in a CPU-only run.
        knn = FaissKNeighbors(is_cuda=code_embs.is_cuda)

        mrrs, exact_matches, top3s = [], [], []
        for beg_idx in tqdm(range(0, len(code_embs), chunk_size), desc='Calculating Metrics:'):
            doc_chunk = doc_embs[beg_idx:beg_idx + chunk_size]
            code_chunk = code_embs[beg_idx:beg_idx + chunk_size]
            n = len(code_chunk)  # the final chunk may be shorter than chunk_size
            knn.fit(code_chunk)

            # Full ranking: retrieve all n candidates for every docstring.
            preds = knn.predict(doc_chunk, k=n).cpu()
            mrrs.append(self.calculate_mrr(preds, n))
            exact_matches.append(self.calculate_hit_rate(preds, 1))
            top3s.append(self.calculate_hit_rate(preds, 3))

        mrr = torch.stack(mrrs).mean().item()
        exact_match = torch.stack(exact_matches).mean().item()
        top3 = torch.stack(top3s).mean().item()
        return mrr, exact_match, top3

    def calculate_mrr(self, preds, k):
        """Mean reciprocal rank of the matching item.

        ``preds`` is a (k, k) tensor of ranked candidate indices; row ``i``'s
        correct candidate is index ``i``.
        """
        targets = torch.arange(k).unsqueeze(1)
        # Column of the match within each row, converted to a 1-based rank.
        ranks = torch.argwhere(torch.eq(preds, targets))[:, 1] + 1
        return torch.mean(1 / ranks).cpu()

    def calculate_hit_rate(self, preds, k):
        """Fraction of rows whose own index appears among the first ``k`` predictions.

        Works even when ``preds`` has fewer than ``k`` columns (a chunk
        smaller than ``k``): the (n, 1) target column broadcasts against
        however many columns are available.
        """
        targets = torch.arange(len(preds)).unsqueeze(1)
        return torch.eq(preds[:, :k], targets).any(1).sum() / len(preds)

    def configure_optimizers(self):
        learning_rate = 1e-5
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

In [9]:
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import NeptuneLogger


# Stop when `valid_mrr` fails to improve by at least FLAGS['es_min_delta'] for
# one validation run, and keep only the best checkpoint by the same metric.
early_stopping = EarlyStopping('valid_mrr', patience=1, mode='max', min_delta=FLAGS['es_min_delta'])
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", save_top_k=1, monitor="valid_mrr", mode='max')

# SECURITY: the Neptune API token must not be stored in the notebook. Export it
# as NEPTUNE_API_TOKEN before starting the kernel. The token that was previously
# committed here should be treated as leaked and revoked.
neptune_logger = NeptuneLogger(
    api_key=os.environ.get("NEPTUNE_API_TOKEN"),
    project="gfx73/code-search",  # format "<WORKSPACE/PROJECT>"
    tags=["training on val"],
)

model = Model()
trainer = pl.Trainer(precision=16,
                     accelerator="gpu",
                     strategy="ddp_fork",
                     devices=FLAGS['num_devices'],
                     # Accumulate gradients towards ~64 samples per optimizer
                     # step; never below 1 (the floor division yields 0 once
                     # the batch size exceeds 64).
                     accumulate_grad_batches=max(1, 64 // FLAGS['batch_size']),
                     callbacks=[TQDMProgressBar(refresh_rate=1), early_stopping, checkpoint_callback],
                     logger=neptune_logger,
                     max_epochs=FLAGS['max_epochs'],
                     log_every_n_steps=10,
                     check_val_every_n_epoch=1,
)

Downloading:   0%|          | 0.00/476M [00:00<?, ?B/s]

In [10]:
# Baseline first: score the freshly loaded (not yet fine-tuned) model on the
# validation split, then fine-tune with validation after every epoch.
trainer.validate(model, valid_loader)
trainer.fit(model, train_loader, valid_loader)



Validation: 0it [00:00, ?it/s]



Calculating Metrics::   0%|          | 0/7 [00:00<?, ?it/s]

https://app.neptune.ai/gfx73/code-search/e/COD-10
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.




Training: 0it [00:00, ?it/s]



https://app.neptune.ai/gfx73/code-search/e/COD-11
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


Validation: 0it [00:00, ?it/s]

Calculating Metrics::   0%|          | 0/7 [00:00<?, ?it/s]

Validation: 0it [00:00, ?it/s]

Calculating Metrics::   0%|          | 0/7 [00:00<?, ?it/s]

In [11]:
# Final evaluation on the test split.
# NOTE(review): the in-memory `model` is passed directly; presumably these are
# the weights from the end of training rather than the best checkpoint saved by
# `checkpoint_callback` -- confirm which one should be reported.
trainer.test(model, test_loader)



Testing: 0it [00:00, ?it/s]



Calculating Metrics::   0%|          | 0/8 [00:00<?, ?it/s]

https://app.neptune.ai/gfx73/code-search/e/COD-12
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


[{'test_loss': 0.2570439577102661,
  'test_mrr': 0.856471061706543,
  'test_exact_match': 0.7842847108840942,
  'test_top3': 0.9191597104072571}]