flan-t5-base fine-tuned with PEFT (IA3); each input is augmented with the top-k profile titles ranked by embedding (dot-product) similarity to `'\n\n'.join([ ex['ref1'], ex['ref2'] ])`

In [None]:
def peft_t5_baselines(k):
    """Train and evaluate an IA3-adapted flan-t5-base classifier using top-k profile titles.

    For every example, the profile titles other than the example's own title are
    ranked by the dot product of their ``train.embed`` embeddings with the embedding
    of ``ex['ref1']`` and ``ex['ref2']`` joined by a blank line.  The best ``k`` are
    quoted, joined with ``' and '`` and passed to ``train.append_to_title`` to build
    the model input.

    Training and dev batches are interleaved proportionally within each of two
    passes.  Every batch is scored before any update on it; only training batches
    are then learned from.  Loss, train accuracy and dev accuracy are reported via
    ``ProgressPrinter``.  Nothing is returned.

    Args:
        k: number of profile titles to include per example (0 means none).  If a
            profile has fewer than ``k`` candidate titles, all of them are used.
    """
    from MegaT5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=4)
    dev = dev_loader(batch_size=4)

    exhausted = object()  # sentinel returned by next() once an iterator runs dry

    def interleave(a, b):
        """Merge the batches of ``a`` and ``b`` in proportion to their sizes.

        Yields ``(True, batch)`` for batches from ``a`` and ``(False, batch)`` for
        batches from ``b``.  At each step it takes from whichever source is further
        behind, comparing batches consumed relative to ``num_examples``.  That ratio
        is proportional only because both loaders use the same batch size.  Both
        sources are drained completely; an empty source is tolerated.
        """
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = iter(a), iter(b)
        aelem, belem = next(aiter, exhausted), next(biter, exhausted)
        anum, bnum = 1, 1

        # Continue until BOTH are exhausted; once one side is done, the other's
        # remaining batches are still emitted rather than dropped.
        while aelem is not exhausted or belem is not exhausted:
            take_a = belem is exhausted or (aelem is not exhausted and anum * btot <= bnum * atot)
            if take_a:
                yield (True, aelem)
                aelem = next(aiter, exhausted)
                anum += 1
            else:
                yield (False, belem)
                belem = next(biter, exhausted)
                bnum += 1

    def top_k_titles(ex):
        """Return up to ``k`` of ``ex``'s profile titles, most similar first.

        The example's own title is excluded from the candidates.  Indices from
        ``topk`` refer to this filtered candidate list, so the titles are looked
        up there, not in the unfiltered ``ex['profile']``.
        """
        candidates = [v['title'] for v in ex['profile'] if v['title'] != ex['title']]
        count = min(k, len(candidates))  # topk raises if asked for more than exist
        if count == 0:
            return []
        embeddings = train.embed(['\n\n'.join([ex['ref1'], ex['ref2']])] + candidates)
        scores = embeddings[0, :] @ embeddings[1:, :].T
        best = torch.topk(scores, dim=0, k=count).indices.tolist()
        return [candidates[i] for i in best]

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    fewshot = PeftT5Classifier(train.num_labels, peft_config, t5=t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, warnings.catch_warnings():
        # Suppress warnings whose message mentions MatMul8bitLt.
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")

        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                with torch.no_grad():
                    # Binary target: 1 when the label equals train.choices[1], else 0.
                    target = torch.tensor(
                        [int(label == train.choices[1]) for label in labels],
                        dtype=torch.long,
                        device=device,
                    )

                    inputs = []
                    for ex in examples:
                        concat_titles = ' and '.join(f'"{title}"' for title in top_k_titles(ex))
                        inputs.append(train.append_to_title(ex, concat_titles))

                    # Scored before learn() below, so train accuracy is pre-update accuracy.
                    fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).float().mean().item()

                # Dev batches are only scored; learning happens on training batches only.
                fewloss = fewshot.learn(inputs, target) if istrain else None
                printer.addobs(iteration, fewloss, fewshotacc if istrain else None, fewshotacc if not istrain else None)

            printer.print()
            # NOTE(review): presumably disables ProgressPrinter's automatic periodic output
            # after the first pass — confirm against ProgressPrinter.
            printer.autoprint = False

 
from Fork import SubProcess

# Sweep k = 0..3, one SubProcess context per value.  The experiment runs only on the
# side where `process.parent` is falsy (presumably the forked child — confirm in Fork).
for k in range(4):
    with SubProcess() as process:
        if not process.parent:
            peft_t5_baselines(k)

n                  iter       since      0 loss       since       0 acc       since 0 acc (dev)       since      dt (s)
1                     0           0       0.697       0.697         0.5         0.5           0           0       0.865
2                     0           0       0.775       0.853       0.375        0.25           0           0        1.15
4                     0           0       0.866        1.05       0.417         0.5        0.25        0.25        1.56
8                     0           0       0.829       0.801         0.5       0.562        0.25           0        2.61
16                    0           0       0.784       0.732       0.481       0.458        0.25        0.25        4.64
32                    0           0       0.741       0.697       0.519       0.558         0.5        0.75        8.67
64                    0           0       0.737       0.732        0.51         0.5       0.462       0.429        16.7
128                   0           0     