In [16]:
def baselines():
    from ProgressPrinter import ProgressPrinter
    from FewShot import ZeroShotClassifier, FewShotClassifier, PEFTFewShotClassifier
    from PersonalizedCitation import train_loader
    from peft import IA3Config, TaskType
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    zero = ZeroShotClassifier()
    few = FewShotClassifier()
    peft_config = IA3Config(task_type=TaskType.CAUSAL_LM, fan_in_fan_out=True)
    ft = PEFTFewShotClassifier(peft_config)

    with ProgressPrinter('0 shot acc', '1 shot acc', 'ft acc', 'ft loss') as printer:
        for inputs, profiles, answers in train:
            with torch.no_grad():
                multichoices = [ ( i, train.choices ) for i in inputs ] 
                zeroguesses = zero(multichoices)
                zeroreward = torch.Tensor([ float(guess == answer) for guess, answer in zip(zeroguesses, answers) ]).mean().item()

                shots = []
                for input_embedding, profile in zip(train.embed(inputs), profiles):
                    profile_embeddings = train.embeddings(profile)
                    indices = torch.topk(input_embedding @ profile_embeddings.T, dim=0, k=1).indices.to('cpu')
                    shots.append(train.stringify_articles(indices))

                fewguesses = few(multichoices, shots)[0]
                fewreward = torch.Tensor([ float(guess == answer) for guess, answer in zip(fewguesses, answers) ]).mean().item()

                ftguesses, ftguessindices = ft(multichoices, shots)
                ftreward = torch.Tensor([ float(guess == answer) for guess, answer in zip(ftguesses, answers) ])

            loss = ft.bandit_learn(multichoices, shots, ftguessindices, (2*ftreward-1).to(device))
                
            printer.addobs(zeroreward, fewreward, (1/2)*(ftreward.mean().item() + 1), loss)

from Fork import SubProcess
with SubProcess() as process: process.parent or baselines()

n          0 shot acc      since 1 shot acc      since     ft acc      since    ft loss      since     dt (s)
1                   1          1          1          1          1          1     -0.319     -0.319       1.82
2                 0.5          0        0.5          0       0.75        0.5     -0.156    0.00718       2.81
4               0.625       0.75      0.625       0.75      0.812      0.875    -0.0808    -0.0057       5.51
8               0.688       0.75      0.688       0.75      0.844      0.875    -0.0416   -0.00252       9.09
16              0.594        0.5      0.594        0.5      0.797       0.75    -0.0211  -0.000459       15.8
32                0.5      0.406        0.5      0.406       0.75      0.703    -0.0103   0.000473       30.6
64              0.453      0.406      0.453      0.406      0.727      0.703   -0.00466   0.000968       63.6
128             0.488      0.523      0.488      0.523      0.744      0.762   -0.00276  -0.000865        125
256       