In [2]:
import importlib.util
import inspect
import io
import json
import os
import tokenize
import warnings

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer

from model import Seq2Seq
from utils import Example, convert_examples_to_features

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
## We are defining all the needed functions here. 
def inference(data, model, tokenizer):
    """Run the model over ``data`` on the CPU and decode its predictions.

    The whole dataset is fed to the model as a single batch.

    Args:
        data: A sized dataset yielding ``(source_ids, source_mask)`` tensor
            pairs, as produced by ``get_features``.
        model: A callable accepting ``source_ids`` and ``source_mask`` keyword
            arguments and exposing ``eval()``. For each example it must return
            an indexable collection of token-id tensors; only index 0 is
            decoded (presumably the best beam -- confirm against the model).
        tokenizer: Object exposing ``decode(ids, clean_up_tokenization_spaces=...)``.

    Returns:
        A ``(texts, source_length)`` tuple: the decoded string for every
        example, and the size of the last dimension of the source id tensor.
        An empty dataset yields ``([], 0)``.
    """
    # DataLoader rejects batch_size=0, so an empty dataset is answered directly.
    if len(data) == 0:
        return ([], 0)

    eval_sampler = SequentialSampler(data)
    eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))

    model.eval()
    predictions = []
    source_length = 0
    with torch.no_grad():
        for batch in eval_dataloader:
            source_ids, source_mask = (tensor.to('cpu') for tensor in batch)
            source_length = source_ids.shape[-1]
            preds = model(source_ids=source_ids, source_mask=source_mask)
            for pred in preds:
                token_ids = pred[0].cpu().tolist()
                # Id 0 marks the end of the useful output; drop it and
                # everything after it before decoding.
                if 0 in token_ids:
                    token_ids = token_ids[: token_ids.index(0)]
                predictions.append(
                    tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
                )
    return (predictions, source_length)


def get_features(examples, tokenizer, max_source_length=256):
    """Convert examples into a ``TensorDataset`` of source ids and masks.

    Args:
        examples: Examples handed to ``convert_examples_to_features``.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``.
        max_source_length: Upper bound on the number of ids / mask entries
            kept per example; longer sequences are truncated. Defaults to 256.

    Returns:
        A ``TensorDataset`` whose items are ``(source_ids, source_mask)``
        pairs of ``torch.long`` tensors.
    """
    features = convert_examples_to_features(examples, tokenizer, stage="test")
    # NOTE(review): torch.tensor needs rows of equal length, so this assumes
    # the feature converter pads every example to a common size -- confirm.
    all_source_ids = torch.tensor(
        [feature.source_ids[:max_source_length] for feature in features],
        dtype=torch.long,
    )
    all_source_mask = torch.tensor(
        [feature.source_mask[:max_source_length] for feature in features],
        dtype=torch.long,
    )
    return TensorDataset(all_source_ids, all_source_mask)


def build_model(
    model_class,
    config,
    tokenizer,
    weights_path="pytorch_model.bin",
    beam_size=10,
    max_length=128,
    num_decoder_layers=6,
):
    """Assemble the encoder/decoder ``Seq2Seq`` model and load its weights on CPU.

    Args:
        model_class: Encoder class; instantiated as ``model_class(config=config)``.
        config: Configuration supplying ``hidden_size`` and
            ``num_attention_heads`` for the decoder layers.
        tokenizer: Tokenizer whose ``cls_token_id`` / ``sep_token_id`` are used
            as the start- and end-of-sequence ids.
        weights_path: Path of the checkpoint to load.
        beam_size: Beam size forwarded to ``Seq2Seq``.
        max_length: Maximum length forwarded to ``Seq2Seq``.
        num_decoder_layers: Number of stacked transformer decoder layers.

    Returns:
        The ``Seq2Seq`` model with the checkpoint weights loaded.
    """
    encoder = model_class(config=config)
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=config.hidden_size, nhead=config.num_attention_heads
    )
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
    model = Seq2Seq(
        encoder=encoder,
        decoder=decoder,
        config=config,
        beam_size=beam_size,
        max_length=max_length,
        sos_id=tokenizer.cls_token_id,
        eos_id=tokenizer.sep_token_id,
    )

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))

    # strict=False tolerates key mismatches; report them instead of silently
    # leaving part of the model at its random initialisation.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {weights_path!r} does not match the model exactly: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}"
        )
    return model

In [4]:
# The tokenizer and the configuration both come from the same pretrained
# CodeBERT checkpoint.
tokenizer = RobertaTokenizer.from_pretrained(
    "microsoft/codebert-base",
    do_lower_case=False,
)
config = RobertaConfig.from_pretrained("microsoft/codebert-base")

# Assemble the network around a RoBERTa encoder, then keep it on the CPU.
model = build_model(RobertaModel, config, tokenizer)
model = model.to('cpu')

In [5]:
# Summarize every file found in the input directory.
SUMMARIZATION_DIR = 'files-for-summarization'

for file_name in sorted(os.listdir(SUMMARIZATION_DIR)):
    file_path = os.path.join(SUMMARIZATION_DIR, file_name)
    # Sub-directories cannot be read as source text; skip them.
    if not os.path.isfile(file_path):
        continue
    # Read the file the loop is currently visiting. The `with` block closes
    # the handle, so no explicit close() is needed.
    with open(file_path, 'r', encoding='utf-8') as f:
        body = f.read()
    example = [Example(source=body, target=None)]
    message, length = inference(get_features(example, tokenizer), model, tokenizer)
    print('message: ', message)


message:  ['Checks that all numbers in a list are equal .']


In [None]:
import inspect
import importlib


def _body_after_signature(source):
    """Return the text following the colon that ends the ``def`` header in ``source``."""
    depth = 0
    seen_def = False
    try:
        for token in tokenize.generate_tokens(io.StringIO(source).readline):
            if token.type == tokenize.NAME and token.string == "def":
                seen_def = True
            elif token.type == tokenize.OP:
                if token.string in ("(", "[", "{"):
                    depth += 1
                elif token.string in (")", "]", "}"):
                    depth -= 1
                elif token.string == ":" and seen_def and depth == 0:
                    # Colons inside brackets (annotations, dict defaults,
                    # lambdas) have depth > 0, so this one ends the header.
                    row, col = token.end
                    return "\n".join(source.split("\n")[row - 1:])[col:]
    except (tokenize.TokenError, IndentationError, SyntaxError):
        pass
    # No ``def`` header found (e.g. a lambda bound to a name): fall back to
    # everything after the first colon.
    return source.partition(":")[2]


def function_dissimator(module_name, path_to_module):
    """Load a Python file and describe the public functions it defines.

    The file is executed as a module named ``module_name``. Only plain
    functions defined in that file are reported; names starting with ``_``,
    classes, non-callables and callables imported from elsewhere are skipped.

    Args:
        module_name: Name given to the loaded module.
        path_to_module: Path of the Python source file to load.

    Returns:
        A list of dicts with keys ``"name"``, ``"docstring"`` (the cleaned
        docstring or ``None``) and ``"body"`` (the source text following the
        function header).

    Raises:
        ImportError: If no import spec or loader can be created for the path.
    """
    spec = importlib.util.spec_from_file_location(module_name, path_to_module)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module {module_name!r} from {path_to_module!r}"
        )
    module = importlib.util.module_from_spec(spec)
    # NOTE(review): this runs the file's top-level code -- only use it on
    # trusted sources.
    spec.loader.exec_module(module)

    funcs = []
    for name, value in vars(module).items():
        if name.startswith("_") or not inspect.isfunction(value):
            continue
        # Functions imported into the module belong to another file.
        if value.__module__ != module.__name__:
            continue
        try:
            source = inspect.getsource(value)
        except (OSError, TypeError):
            # Source is unavailable (e.g. dynamically created function).
            continue
        funcs.append(
            {
                "name": name,
                "docstring": inspect.getdoc(value),
                "body": _body_after_signature(source),
            }
        )

    return funcs

In [6]:
import os
# Write a summary of every function in every Python file of the extracted
# archive. Entries are visited in sorted order so the report is reproducible.
# NOTE(review): function_dissimator executes each file -- only run this on
# trusted archives.
with open('summary.txt', 'w', encoding='utf-8') as summary:
    for file_name in sorted(os.listdir('files-from-zip')):
        module_name, extension = os.path.splitext(file_name)
        if extension != '.py':
            continue
        summary.write(f'\nFile name: {file_name}\n')
        module_path = os.path.join('files-from-zip', file_name)
        for func in function_dissimator(module_name, module_path):
            summary.write(f'Function name: {func["name"]}\n')
            body = func["body"]
            example = [Example(source=body, target=None)]
            # Only the decoded text is needed; the source length is ignored.
            message, _ = inference(get_features(example, tokenizer), model, tokenizer)
            summary.write(f'Summary: {message}\n\n')
# The `with` block already closed the file; no explicit close() is required.