# flan-t5-base 

peft ia3, top-k titles based upon `'\n\n'.join([ ex['ref1'], ex['ref2'] ])` ... not as good

peft ia3, top-k titles based upon max similarity with ref1 and ref2 ... best result yet

In [1]:
# alpha=5_000, full finetune
def peft_t5_baselines(k):
    from MegaT5 import T5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from transformers import T5ForConditionalGeneration
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    import parameterfree
    opt_factory = lambda params: parameterfree.COCOB(params, alpha=5_000)
    fewshot = T5Classifier(train.num_labels, t5=t5, opt_factory=opt_factory)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        target = torch.Tensor([ int(label == train.choices[1]) for label in labels ]).long().to(device)
        
                        for ex in examples:
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            titles = [ f'"{ex["profile"][ind]["title"]}"' for ind in index.tolist() ]
                            concat_titles = ' and '.join(titles)
                            input = train.append_to_title(ex, concat_titles)
                            inputs.append(input)
        
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).float().mean().item()
    
                    fewloss = fewshot.learn(inputs, target) if istrain else None
                    printer.addobs(iteration, fewloss, fewshotacc if istrain else None, fewshotacc if not istrain else None)

            printer.print()
            printer.autoprint = False
 
from Fork import SubProcess
for k in range(0, 5):
    with SubProcess() as process: process.parent or peft_t5_baselines(k)

n                  iter       since      0 loss       since       0 acc       since 0 acc (dev)       since      dt (s)
1                     0           0       0.674       0.674           1           1           0           0        1.28
2                     0           0         3.1        5.52         0.5           0           0           0        1.82
4                     0           0        2.57        1.53       0.333           0           0           0        2.59
8                     0           0        1.52       0.731         0.5       0.625           0           0        4.69
16                    0           0        1.25        0.93       0.423       0.333         0.5        0.75        8.24
32                    0           0       0.989       0.729       0.462         0.5       0.583       0.667        16.4
64                    0           0       0.856       0.719       0.529         0.6       0.731       0.857        31.4
128                   0           0     

# flan-t5-xl (8bit)

peft ia3, top-k titles based upon max similarity with ref1 and ref2

better than flan-t5-base.  maybe flan-t5-xxl is even better. unfortunately flan-t5-xxl doesn't fit on a T4 in 8bit, and 4bit doesn't seem to work.

In [None]:
def peft_t5_baselines(k):
    from MegaT5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xl', load_in_8bit=True))
    fewshot = PeftT5Classifier(train.num_labels, peft_config, t5=t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        target = torch.Tensor([ int(label == train.choices[1]) for label in labels ]).long().to(device)
        
                        for ex in examples:
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            titles = [ f'"{ex["profile"][ind]["title"]}"' for ind in index.tolist() ]
                            concat_titles = ' and '.join(titles)
                            input = train.append_to_title(ex, concat_titles)
                            inputs.append(input)
        
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).float().mean().item()
    
                    fewloss = fewshot.learn(inputs, target) if istrain else None
                    printer.addobs(iteration, fewloss, fewshotacc if istrain else None, fewshotacc if not istrain else None)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(0, 5):
    with SubProcess() as process: process.parent or peft_t5_baselines(k)

n                  iter       since      0 loss       since       0 acc       since 0 acc (dev)       since      dt (s)
1                     0           0       0.775       0.775           0           0           0           0        2.43
2                     0           0        2.23        3.68           0           0           0           0        3.93
4                     0           0        1.88        1.19           0           0           0           0        6.04
8                     0           0        1.24       0.766       0.214       0.375           0           0          12
16                    0           0        1.09        0.92       0.231        0.25         0.5        0.75          22
32                    0           0       0.911       0.728       0.365         0.5       0.583       0.667        43.9
64                    0           0        0.81       0.705       0.471        0.58       0.731       0.857          86
128                   0           0     

Bad pipe message: %s [b'(\x96\xb4\x04\xa6V4\x15\x1f\x82\x8e\xfeH64\xe7G: 0\xa27\xa1{\xe8\x16X\xe7\x7f\xa5\x8b\xfe=\x05\xe2\x95\xd4\xd8\x06\xe15\xc8\x8b\xa8\x05\x05\x8c-\x0ck\xb0\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 \xa9x!~\x9a\xc6\x96Pg']
Bad pipe message: %s [b'\xc9\xd0Zw\xac\xad\x0bt\x7f2FS\xb8\xd1\x96wS\x02\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0', b"\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0

6090                  0           0       0.546       0.544        0.72       0.716       0.706       0.731     8.6e+03
7339               0.17           1       0.546           0        0.72           0       0.716       0.727    9.41e+03
n                  iter       since      3 loss       since       3 acc       since 3 acc (dev)       since      dt (s)
1                     0           0       0.642       0.642         0.5         0.5           0           0        2.53
2                     0           0        2.32        3.99        0.25           0           0           0        4.15
4                     0           0        1.95        1.21       0.167           0           0           0        6.48
8                     0           0        1.27       0.758       0.429       0.625           0           0          13
16                    0           0        1.12       0.952       0.385       0.333         0.5        0.75        24.1
32                    0           0     