In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from typet5.data import GitRepo, ModuleRemapUnpickler
from typet5.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from typet5.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Make the project root the working directory for the rest of the notebook.
os.chdir(proj_root())

# The data directory is taken from the `datadir` environment variable.
# `os.getenv` returns None when the variable is unset, and `Path(None)` would
# then fail with an unhelpful TypeError, so check for that case explicitly.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "The 'datadir' environment variable is not set; "
        "point it at the data directory before running this notebook."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"


def rename_module(n: str) -> str:
    """Return "typet5.data" for "typet5.data_prepare"; return any other name unchanged."""
    return "typet5.data" if n == "typet5.data_prepare" else n


# NOTE: unpickling can execute arbitrary code, so only load a pickle file that
# comes from a trusted source.
# `rename_module` is handed to the unpickler; presumably it is applied to the
# module names stored in the pickle (confirm against `ModuleRemapUnpickler`).
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = ModuleRemapUnpickler(f, rename_module).load()

In [2]:
# Load the pre-trained model and tokenizer.

# Checkpoint directory, located under `datadir`.
model_dir = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin/"

import torch
from transformers import (
    DataCollatorForSeq2Seq,
    RobertaTokenizer,
    T5ForConditionalGeneration,
)
from transformers.models.t5 import T5ForConditionalGeneration

# Use the first CUDA device when CUDA is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The tokenizer and the model are both loaded from the same checkpoint directory.
tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_dir)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    model_dir
)
model = model.to(device)

# Upper bound on the generated sequence length (passed as `max_length` to
# `model.generate` in `run_model`).
max_target_length = 128



In [25]:
from typet5.data import mask_type_annots, output_ids_as_types, tokenize_masked

# Sample Python source that is passed to `run_model` below. It is kept as a
# plain string: its annotations are the ones the model is asked to predict.
test_code = """
@dataclass
class GitRepo:
    author: str
    name: str
    url: str
    stars: int
    forks: int

    def authorname(self):
        return self.author + "__" + self.name

    def repo_dir(self, repos_dir: Path) -> Path:
        return repos_dir / "downloaded" / self.authorname()

    def download(self, repos_dir: Path, timeout=None) -> bool:
        pass
"""


def run_model(code: str, num_beams=16):
    """Mask the type annotations in ``code`` and let the model predict them.

    The source is passed through ``mask_type_annots`` and ``tokenize_masked``,
    then fed to the global ``model`` twice: once as a forward pass to obtain
    the loss, and once through ``generate`` (beam search) to obtain the
    predicted output sequence. Gradients are disabled for both calls.

    Args:
        code: Python source code whose type annotations should be predicted.
        num_beams: Number of beams passed to ``model.generate``.

    Returns:
        A dict with the following keys:
            "loss": the ``.loss`` of the forward pass over the tokenized input.
            "predicted_types": ``output_ids_as_types`` applied to the generated ids.
            "labels": ``output_ids_as_types`` applied to the first label sequence.
            "generation": the generated ids decoded to text by the tokenizer.
            "input_ids": the first (only) sequence of input ids.
            "output_ids": the first sequence returned by ``model.generate``.
    """
    # NOTE(review): "no_source" looks like a placeholder path, since the code is
    # supplied as a string rather than read from a file; confirm against
    # `mask_type_annots`.
    masked = mask_type_annots((Path("no_source"), code))
    tks = tokenize_masked(masked, tokenizer, device)
    input_ids = tks["input_ids"]
    with torch.no_grad():
        # Call the module itself rather than `.forward`, so that any hooks
        # registered on the model are run as well.
        loss = model(**tks).loss
        # Keep only the first returned sequence.
        dec = model.generate(
            input_ids,
            max_length=max_target_length,
            num_beams=num_beams,
        )[0]
    return {
        "loss": loss,
        "predicted_types": output_ids_as_types(dec, tokenizer),
        "labels": output_ids_as_types(tks["labels"][0], tokenizer),
        "generation": tokenizer.decode(dec),
        "input_ids": input_ids[0],
        "output_ids": dec,
    }


# Run the model on the sample source with 10 beams (instead of the default 16).
# The cells that follow read their values from `result`.
result = run_model(test_code, num_beams=10)

In [27]:
# Step 1: every type to be predicted is replaced with a special token.
# Decoding the input ids shows the masked source.
print(tokenizer.decode(result["input_ids"]))

<s>
@dataclass
class GitRepo:
    author:<extra_id_0>
    name:<extra_id_1>
    url:<extra_id_2>
    stars:<extra_id_3>
    forks:<extra_id_4>

    def authorname(self):
        return self.author + "__" + self.name

    def repo_dir(self, repos_dir:<extra_id_5>) -><extra_id_6>:
        return repos_dir / "downloaded" / self.authorname()

    def download(self, repos_dir:<extra_id_7>, timeout=None) -><extra_id_8>:
        pass
</s>


In [28]:
# Step 2: the masked source is tokenized using Byte Pair Encoding (BPE).
# Show the individual input tokens.
print(tokenizer.convert_ids_to_tokens(result["input_ids"]))

['<s>', 'Ċ', '@', 'data', 'class', 'Ċ', 'class', 'ĠGit', 'Repo', ':', 'Ċ', 'ĠĠĠ', 'Ġauthor', ':', '<extra_id_0>', 'Ċ', 'ĠĠĠ', 'Ġname', ':', '<extra_id_1>', 'Ċ', 'ĠĠĠ', 'Ġurl', ':', '<extra_id_2>', 'Ċ', 'ĠĠĠ', 'Ġstars', ':', '<extra_id_3>', 'Ċ', 'ĠĠĠ', 'Ġfor', 'ks', ':', '<extra_id_4>', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġauthor', 'name', '(', 'self', '):', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġself', '.', 'author', 'Ġ+', 'Ġ"__', '"', 'Ġ+', 'Ġself', '.', 'name', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġrepo', '_', 'dir', '(', 'self', ',', 'Ġrepos', '_', 'dir', ':', '<extra_id_5>', ')', 'Ġ->', '<extra_id_6>', ':', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġrepos', '_', 'dir', 'Ġ/', 'Ġ"', 'down', 'loaded', '"', 'Ġ/', 'Ġself', '.', 'author', 'name', '()', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġdownload', '(', 'self', ',', 'Ġrepos', '_', 'dir', ':', '<extra_id_7>', ',', 'Ġtimeout', '=', 'None', ')', 'Ġ->', '<extra_id_8>', ':', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġpass', 'Ċ', '</s>']


In [29]:
# Step 3: the model predicts a sequence of types, emitted as BPE tokens.
# Show the raw output tokens.
print(tokenizer.convert_ids_to_tokens(result["output_ids"]))

['<pad>', '<s>', '<extra_id_0>', 'str', '<extra_id_1>', 'str', '<extra_id_2>', 'str', '<extra_id_3>', 'List', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', '<extra_id_4>', 'List', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', '<extra_id_5>', 'Path', '<extra_id_6>', 'Path', 'Ġ.', 'ĠPath', '<extra_id_7>', 'Path', 'Ġ.', 'ĠPath', '<extra_id_8>', 'Path', 'Ġ.', 'ĠPath', 'Ġ[', 'Ġstr', ']', '</s>']


In [30]:
# Step 4: extract the predicted types from the output sequence.
print(result["predicted_types"])

[str, str, str, Any, Any, Path, Path.Path, Path.Path, Path.Path[str]]
