In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 48      # e.g. GPT-2 XL
    n_head: int = 25
    n_embd: int = 1600

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    """
    A GPT-like model that only stores the hidden state *after*
    'skip_up_to - 1' layers (the "Block k" state).

    For skipping:
      - if we detect a copy scenario, we load the cached partial
        hidden state from t_matched, run the last layers only.
      - else, we run all layers and store the partial state.
    """

    def __init__(self, config, skip_up_to=43):
        super().__init__()
        self.config = config
        self.skip_up_to = skip_up_to   # number of layers to skip
        self.transformer = nn.ModuleDict({
            'wte':  nn.Embedding(config.vocab_size, config.n_embd),
            'wpe':  nn.Embedding(config.block_size, config.n_embd),
            'h':    nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # cache_partial: only store the hidden state after skip_up_to - 1 layers
        # so cache_partial[(batch_idx, pos)] = Tensor shape (B, T, n_embd)
        self.cache_partial = {}

    # def forward(self, input_ids, pos_matched=None):
    #     """
    #     input_ids: (B, T), for simplicity assume B=1
    #     pos_matched: int or None.
    #       If None => normal forward on entire sequence
    #       If not None => "copy" scenario.
    #     """
    #     B, T = input_ids.shape
    #     device = input_ids.device

    #     # Basic embed
    #     pos = torch.arange(0, T, device=device)
    #     full_x = self.transformer['wte'](input_ids) + self.transformer['wpe'](pos)

    #     if pos_matched is None:
    #         # Normal forward for all T tokens from layer 0..n_layer
    #         x = full_x
    #         for block in self.transformer['h']:
    #             x = block(x)

    #     else:
    #         # "Copy" scenario
    #         # 1) separate the sequence into first T-1 tokens vs the newly added token
    #         x_trunc = full_x[:, :-1, :]  # shape = (1, T-1, n_embd)

    #         # 2) compute from layer 0..skip_up_to on the truncated x
    #         for layer_idx in range(self.skip_up_to + 1):
    #             x_trunc = self.transformer['h'][layer_idx](x_trunc)

    #         # 3) get matched hidden from x_trunc for the new token
    #         #    note that pos_matched must be < T-1, so we can do x_trunc[:, pos_matched, :]
    #         matched_hid = x_trunc[:, pos_matched, :]  # shape (1, n_embd)
    #         matched_hid = matched_hid.unsqueeze(1)    # => shape (1, 1, n_embd)

    #         # 4) cat matched hidden to x_trunc => new shape (1, T, n_embd)
    #         x = torch.cat([x_trunc, matched_hid], dim=1)

    #         # 5) continue from layer skip_up_to+1..end on the full (1, T, n_embd)
    #         for layer_idx in range(self.skip_up_to + 1, self.config.n_layer):
    #             x = self.transformer['h'][layer_idx](x)

    #     # final layer norm + logits
    #     x = self.transformer['ln_f'](x)
    #     logits = self.lm_head(x)  # shape (B, T, vocab_size)
    #     return logits

    def forward(self, input_ids, batch_idx=0, pos_new=None, pos_matched=None):
        """
        input_ids: shape (B, T)
        pos_new: which position we're generating if single-token stepping
        pos_matched: if copying from matched position
        """
        B, T = input_ids.shape
        device = input_ids.device

        # token + pos embedding
        positions = torch.arange(0, T, device=device)
        x = self.transformer['wte'](input_ids) + self.transformer['wpe'](positions)

        # Skip scenario?
        if pos_matched is not None and (batch_idx, pos_matched) in self.cache_partial:
            # Reuse the partial hidden state from pos_matched
            # instead of computing layers [0..skip_up_to-1].
            partial_hid = self.cache_partial[(batch_idx, pos_matched)].to(device)
            x = partial_hid
            start_layer = self.skip_up_to
        else:
            # we have to compute all layers or up to skip_up_to
            start_layer = 0

        # Actually run the layers from start_layer to end
        for layer_idx in range(start_layer, self.config.n_layer):
            block = self.transformer['h'][layer_idx]
            x = block(x)

            # If layer_idx == skip_up_to-1, store partial
            if layer_idx == self.skip_up_to - 1 and pos_new is not None:
                # store in CPU to reduce GPU memory
                self.cache_partial[(batch_idx, pos_new)] = x.detach().cpu()

        x = self.transformer['ln_f'](x)
        logits = self.lm_head(x)
        return logits

    @classmethod
    def from_pretrained(cls, model_type, skip_up_to=43):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config, skip_up_to=skip_up_to)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

In [2]:
from transformers import GPT2Tokenizer

model_name = 'gpt2-xl'
num_layer_cal = 10
skip_up_to = 47-num_layer_cal

device1 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

model_copy = GPT.from_pretrained(model_name, skip_up_to=0)
model_copy = model_copy.to(device1)
model_copy.eval()

  from .autonotebook import tqdm as notebook_tqdm


loading weights from pretrained gpt: gpt2-xl


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (h): ModuleList(
      (0-47): 48 x Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=1600, out_features=4800, bias=True)
          (c_proj): Linear(in_features=1600, out_features=1600, bias=True)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=1600, out_features=6400, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=6400, out_features=1600, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1600, out_features=50257, bias=False)
)

In [11]:
code = """def breadth_first_search(startnode, goalnode):
    queue = Queue()
    queue.append(startnode)

    nodesseen = set()
    nodesseen.add(startnode)

    while True:
        node = queue.popleft()

        if node is goalnode:
            return True
        else:
            queue.extend(node for node in node.successors if node not in nodesseen)
            nodesseen.update(node.successors)

    return False"""

prompt = f"Given the following code is incorrect:\n{code}\nCorrected code:"
code_ids = tokenizer.encode(code, return_tensors='pt')
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device1)

# generate 50 tokens and then decode
steps = 20 + code_ids.size(1)
for step in range(steps):
    logits = model_copy(input_ids)
    next_token = torch.argmax(logits[0, -1, :])
    input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1)

print(tokenizer.decode(input_ids[0]))


Given the following code is incorrect:
def breadth_first_search(startnode, goalnode):
    queue = Queue()
    queue.append(startnode)

    nodesseen = set()
    nodesseen.add(startnode)

    while True:
        node = queue.popleft()

        if node is goalnode:
            return True
        else:
            queue.extend(node for node in node.successors if node not in nodesseen)
            nodesseen.update(node.successors)

    return False
Corrected code:
def breadth_first_search(startnode, goalnode):
    queue = Queue()
    queue.append(startnode)
     nodesseen = set()
     nodesseen.add(startnode)
     while True:
        node = queue.popleft()
          if node is goalnode:
             return True
         else:
           queue.extend(node for node in node.successors if node not in nodesseen)
            nodesseen.update(node.successors)
The above code is correct, but it is not the most efficient. The code is more


In [24]:
logits[:,-1,:]

tensor([[ 0.2359, -0.3648, -4.0851,  ..., -9.3595, -4.7799, -0.3083]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [28]:
probs = torch.softmax(logits[:, -1, :], dim=-1)
top_vals, top_indices = probs.topk(5, dim=-1)
probs, top_vals, top_indices

(tensor([[1.2318e-05, 6.7554e-06, 1.6366e-07,  ..., 8.3809e-10, 8.1692e-08,
          7.1478e-06]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
 tensor([[0.0430, 0.0409, 0.0212, 0.0182, 0.0171]], device='cuda:0',
        grad_fn=<TopkBackward0>),
 tensor([[  517,   407,   845,    25, 30904]], device='cuda:0'))

In [29]:
top_indices[0].tolist()

[517, 407, 845, 25, 30904]

In [2]:
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_top_k(logits, top_k=5):
    """
    logits: (B, T, vocab_size)
    Returns a list of top-k token IDs for the last position, e.g. [id1, id2,...].
    """
    last_logits = logits[:, -1, :]       # shape (B, vocab_size)
    probs = torch.softmax(last_logits, dim=-1)
    top_vals, top_indices = probs.topk(top_k, dim=-1)
    # top_indices is shape (B, top_k). For B=1, we do top_indices[0].tolist().
    return top_indices[0].tolist()

def detect_ngram_copy(seq_ids: torch.Tensor, n=3, skip_up_to=43):
    """
    Minimal function that tries to find n-gram copy scenario
    (just a placeholder – adapt to your real logic)
    """
    T = seq_ids.size(1)  # shape (B=1, T)
    if T < n:
        return None, None
    # 1) last token
    last_token = seq_ids[0, -1].item()
    # 2) find earlier positions of last_token
    possible_pos = (seq_ids[0, :-1] == last_token).nonzero().view(-1)
    if possible_pos.numel() == 0:
        return None, None
    # 3) check (n-1) context
    n_minus_1 = n - 1
    context_needed = seq_ids[0, -(n_minus_1+1):-1]  # last n-1 tokens
    matched_pos = None
    for pos in reversed(possible_pos):
        if pos >= n_minus_1:
            candidate = seq_ids[0, pos-n_minus_1:pos]
            if torch.all(candidate == context_needed):
                matched_pos = pos.item()
                break
    if matched_pos is None:
        return None, None
    else:
        return matched_pos, skip_up_to

In [3]:
from transformers import GPT2Tokenizer

model_name = 'gpt2-xl'
num_layer_cal = 10
skip_up_to = 47-num_layer_cal

device1 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

model_copy = GPT.from_pretrained(model_name, skip_up_to=skip_up_to)
model_copy = model_copy.to(device1)
model_copy.eval()

device2 = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model_original = GPT.from_pretrained(model_name, skip_up_to=0)
model_original = model_original.to(device2)
model_original.eval()

  from .autonotebook import tqdm as notebook_tqdm


loading weights from pretrained gpt: gpt2-xl
loading weights from pretrained gpt: gpt2-xl


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (h): ModuleList(
      (0-47): 48 x Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=1600, out_features=4800, bias=True)
          (c_proj): Linear(in_features=1600, out_features=1600, bias=True)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=1600, out_features=6400, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=6400, out_features=1600, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1600, out_features=50257, bias=False)
)

In [4]:
from datasets import load_from_disk
from tqdm import tqdm

# Load the dataset from disk
subset = load_from_disk("english_insertions")
prompt_list = []

base_sents = subset['train']['base_sentence'][:1000]
phrases = subset['train']['phrase'][:1000]
edited_sents = subset['train']['edited_sentence'][:1000]

import gc
del subset
gc.collect()

0

In [7]:
base_sents[0]

'Dadaji lost his mother and took to living with his maternal uncle Narayan Dhurmaji .'

In [6]:
phrases[0]

'In the meanwhile ,'

In [7]:
import time
import torch
import random
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm
from collections import defaultdict

seed = 5
seed_everything(seed)

extra_steps = 20
max_steps = 1024             
k = 10
n = 5
info_lst = []


for code in tqdm(file_contents):
    prompt = f"Given the following code is incorrect:\n{code}\nCorrected code:"
    code_ids = tokenizer.encode(code, return_tensors='pt')
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    steps = extra_steps + code_ids.size(1)
    steps = min(steps, max_steps)

    dict_pred_info = defaultdict(dict)

    # -------------------------------------
    # 3a) Copy-mech generation
    # -------------------------------------
    copy_ids = input_ids.clone().to(device1)  # shape (1, initial_length)
    # (Optional) fill partial states for the entire existing prompt if you want 
    # to skip from tokens inside the prompt. For example:
    for p in range(copy_ids.size(1)):
        sub_ids = copy_ids[:, :p+1]
        # pos_new=p, pos_matched=None => no skip
        model_copy.forward(sub_ids, batch_idx=0, pos_new=p, pos_matched=None)
        
    flag = False
    for step_i in range(steps):
        t0 = time.time()

        # detect copy
        t_matched, skipv = detect_ngram_copy(copy_ids, n=n, skip_up_to=43)
        
        # if t_matched is not None and copy_ids[0, t_matched].item() in [7783]:
        #     flag = True
        #     t_matched = None
        
        # if flag:
        #     t_matched = None

        logits = model_copy.forward(
            input_ids=copy_ids,
            batch_idx=0,
            pos_new=copy_ids.shape[1] - 1,  # new position
            pos_matched=t_matched
        )
        top_k_list = get_top_k(logits, top_k=k)

        next_token_id = top_k_list[0]
        next_token_tensor = torch.tensor([[next_token_id]], device=device1)
        copy_ids = torch.cat([copy_ids, next_token_tensor], dim=1)

        elapsed_copy = time.time() - t0
        dict_pred_info[step_i]['copy'] = top_k_list
        dict_pred_info[step_i]['copy_time'] = elapsed_copy

    # Clear partial cache for next sample
    model_copy.cache_partial.clear()

    # -------------------------------------
    # 3b) Original model generation (skip_up_to=0)
    # -------------------------------------
    original_ids = input_ids.clone().to(device2)
    # likewise fill partial states for the entire prompt 
    for p in range(original_ids.size(1)):
        sub_ids = original_ids[:, :p+1]
        model_original.forward(sub_ids, batch_idx=0, pos_new=p, pos_matched=None)

    for step_i in range(steps):
        t0 = time.time()

        # always do full pass (pos_matched=None)
        logits = model_original.forward(
            input_ids=original_ids,
            batch_idx=0,
            pos_new=original_ids.shape[1] - 1,
            pos_matched=None
        )
        top_k_list = get_top_k(logits, top_k=k)

        next_token_id = top_k_list[0]
        next_token_tensor = torch.tensor([[next_token_id]], device=device2)
        original_ids = torch.cat([original_ids, next_token_tensor], dim=1)

        elapsed_orig = time.time() - t0
        dict_pred_info[step_i]['original'] = top_k_list
        dict_pred_info[step_i]['original_time'] = elapsed_orig

    # Clear partial cache
    model_original.cache_partial.clear()

    info_lst.append(dict_pred_info)

 10%|█         | 4/40 [01:19<11:57, 19.93s/it]


KeyboardInterrupt: 

In [8]:
def jaccard_similarity(list1, list2):
    set1 = set(list1)
    set2 = set(list2)
    intersection = len(set1.intersection(set2))
    union = len(set1) + len(set2) - intersection
    return  intersection / union

jcc_ult = []
acc_ult = []
tpt_copy = []
tpt_orig = []

for data in info_lst:
    total_copy_time = 0
    total_orig_time = 0
    acc_lst = []
    jc_lst = []
    for step in data.keys():
        total_copy_time += data[step]['copy_time']
        total_orig_time += data[step]['original_time']
        copy = data[step]['copy']
        original = data[step]['original']

        jaccard_score = jaccard_similarity(copy, original)
        jc_lst.append(jaccard_score)

        acc_score = 1 if copy[0] == original[0] else 0
        acc_lst.append(acc_score)

    # time per token - tpt
    tpt_copy.append(total_copy_time / steps)
    tpt_orig.append(total_orig_time / steps)

    jcc_ult.append(jc_lst)
    acc_ult.append(acc_lst)

In [23]:
# tpt_copy / tpt_orig by row 
tpt_ratio = [tpt_copy[i] / tpt_orig[i] for i in range(len(tpt_copy))]

In [9]:
acc_ult[1]

[1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [10]:
sum_tpt_copy = sum(tpt_copy)
sum_tpt_orig = sum(tpt_orig)
sum_tpt_copy / sum_tpt_orig

0.754166875720513

In [11]:
def cal_avg(lsts):
    avg_lst = []
    for lst in lsts:
        avg_lst.append(sum(lst) / len(lst))
    return sum(avg_lst) / len(avg_lst)

avg_jcc = cal_avg(jcc_ult)
avg_acc = cal_avg(acc_ult)
avg_jcc, avg_acc

(0.34435988335702306, 0.36635356332886493)

In [10]:
acc_ult[3]

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1]

English Insertion

In [None]:
from datasets import load_from_disk
from tqdm import tqdm

# Load the dataset from disk
subset = load_from_disk("english_insertions")
prompt_list = []

base_sents = subset['train']['base_sentence'][:1000]
phrases = subset['train']['phrase'][:1000]
edited_sents = subset['train']['edited_sentence'][:1000]

import gc
del subset
gc.collect()