In [21]:
def peft_gpt_baselines():
    from GPT2 import PeftGPT2Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    import torch
    from transformers import AutoModelForCausalLM
    import warnings

    device = 'cuda'
    torch.set_default_device(device)

    train = train_loader(batch_size=4)
    dev = dev_loader(batch_size=4)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    for k in range(1, 4):
        torch.manual_seed(2112)
        peft_config = IA3Config(task_type=TaskType.CAUSAL_LM, fan_in_fan_out=True)
        gpt2 = prepare_model_for_kbit_training(AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-1.3B', load_in_8bit=True))
        fewshot = PeftGPT2Classifier(train.num_labels, peft_config, gpt2=gpt2)
        with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")

            for iteration in range(16):
                for istrain, (examples, labels) in interleave(train, dev):
                    with torch.no_grad():
                        inputs = []
                        target = torch.Tensor([ int(label == train.choices[1]) for label in labels ]).long().to(device)
        
                        for ex in examples:
                            embeddings = train.embed([ ex['title'] ] + 
                                                     [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            index = torch.topk(embeddings[0,:] @ embeddings[1:,:].T, dim=0, k=k).indices.to('cpu')
                            titles = [ f'"{ex["profile"][ind]["title"]}"' for ind in index.tolist() ]
                            concat_titles = ' and '.join(titles)
                            input = train.append_to_title(ex, concat_titles)
                            inputs.append(input)
        
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).float().mean().item()

                    fewloss = fewshot.learn(inputs, target) if istrain else None
                    printer.addobs(iteration, fewloss, fewshotacc if istrain else None, fewshotacc if not istrain else None)

                printer.print()
                printer.autoprint = False

 
from Fork import SubProcess
with SubProcess() as process: process.parent or peft_gpt_baselines()

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.693       0.693         0.5         0.5           0           0        1.59
2                     0           0        2.36        4.03       0.375        0.25           0           0        2.61
4                     0           0        7.89          19       0.417         0.5        0.75        0.75         3.9
8                     0           0        8.22        8.46       0.321        0.25        0.75           0        7.85
16                    0           0        6.74        5.02       0.365       0.417       0.417        0.25        14.7
32                    0           0        4.35        1.96        0.49       0.615       0.417       0.417        28.7
64                    0           0        3.57        2.76       0.495         0.5       0.481       0.536        56.5
128                   0           0     