# Step 1: fine-tune LLM using top result from (fixed) ranker

In [2]:
def step_one(*, k):
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        prompts = []
                        target = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in index.tolist() ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            prompt = train.append_to_title(ex, concat_titles)
                            prompts.append(prompt)
                            target.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        acc = (taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()
    
                    loss = taskllm.learn(prompts, target, augment=train.swap_refs) if istrain else None
                    printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                taskllm.save_pretrained(f'User_keq{k}_t5base_step1')

from Fork import SubProcess
with SubProcess() as process: process.parent or step_one(k=4)

n                  iter       since      4 loss       since       4 acc       since 4 acc (dev)       since      dt (s)
1                     0           0       0.711       0.711           0           0           0           0         1.2
2                     0           0       0.695       0.678         0.5           1           0           0        1.69
4                     0           0       0.689       0.677       0.667           1           1           1        2.46
8                     0           0        0.69       0.691       0.643       0.625           1           0        4.38
16                    0           0       0.675       0.658       0.654       0.667       0.667         0.5        7.73
32                    0           0       0.681       0.686       0.654       0.654       0.833           1        15.5
64                    0           0        0.67       0.659       0.686        0.72       0.808       0.786        29.7
128                   0           0     

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [None]:
def learn_ranker(*, max_iteration, k):
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretGreedyDoubleSampler
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq4_t5base_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(prompts):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in prompts ]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                  [ v['title'] 
                                   for v in ex['profile']
                                   if v['title'] != ex['title'] 
                                 ])
        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
        temperature = scores[0].item() - scores[4].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration + 1 < max_iteration or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            randos = randomized_similarity(ex, 64)
                            
                            rhatprompts = []
                            prompts = []
                            for rando in randos:
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in rando ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                prompt = train.append_to_title(ex, concat_titles)
                                prompts.append(prompt)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n" + '\n'.join(
                                           [ f"Extra: {t}" for t in titles ])
                                rhatprompts.append(rhatprompt)

                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            if len(rhats) > 1:
                                explore, exploit = SimpleRegretGreedyDoubleSampler(rhats.view(1, -1), gamma=10)
                                actionind = [exploit.item(), explore.item()]
                            else:
                                actionind = [0]

                            guesses = taskllm.predict([ prompts[a] for a in actionind ], augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            greedyreward = rewards[0, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn([ rhatprompts[a] for a in actionind ], rewards, augment=reward_augment) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration + 1 < max_iteration:
                rewardpredictor.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}')

from Fork import SubProcess
with SubProcess() as process: process.parent or learn_ranker(k=4, max_iteration=8)

n                  iter       since      4 loss       since       4 acc       since 4 acc (dev)       since      dt (s)
1                     0           0       0.694       0.694           0           0           0           0        2.92
2                     0           0       0.727        0.76           0           0           0           0        6.21
4                     0           0       0.678       0.629         0.5           1           0           0        12.8
8                     0           0       0.602         0.5       0.714           1           1           1        24.3
16                    0           0       0.661       0.713         0.6         0.5           1           0        50.5


Bad pipe message: %s [b"v\x1b)\\\xaa\xf4\xe8}'\xa7\xf7\xb1xaw\x07\xd8. x8\x10\xb6\xe4\xe62\xd5\x1b1\xd4f\x196Z"]
Bad pipe message: %s [b"?\xab\xd4\xf6A\xc9\x1c\xd9PSF\xcd\xeeS%Bn1\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00\xc0\x00<\x00\xba\x005\x00\x84\x00/\x00\x96\x00A\x00\x05\x00\n\x00\xff\x01\x00\x00j\x00\x00\x00"]
Bad pipe message: %s [b"\xcb\xd5:\xa7-o\x14\xbc\xa3\x80\x8b\xe8\xd8\xdeg\xf1\xf6\x96\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0

32                    0           0       0.572       0.477       0.724       0.857       0.667         0.5        97.5
64                    0           0       0.552       0.531       0.737        0.75       0.857           1         198
128                   0           0       0.545       0.538       0.772       0.807       0.786       0.714         393
256                   0           0       0.558       0.571       0.744       0.717       0.793         0.8         792
512                   0           0       0.579       0.601       0.736       0.727       0.655       0.517    1.59e+03
1024                  0           0       0.593       0.607       0.729       0.722       0.718        0.78    3.18e+03
2048                  0           0       0.596         0.6        0.72       0.711       0.667       0.615    6.37e+03


Bad pipe message: %s [b"\x9a\xb9\xfb\x87\xa9YP\xbc'\x97", b"\xfb4\xad\n\xa5q \xdb\x08\xea\xd2\xe5\x0c\xa5'\xd2\xb3\xc95\x80;z\xd6<]h\xe3\x1b0\x16\x0c\x14\xa9\xe7\xae6m\xa03\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00"]
Bad pipe message: %s [b'\xcdU\xc3\x0b!\xc7\xde\xc5a\xdf\x811\xf3\xfcI\xc8p\xc4 \xce_[l\xe0\xf8=p\x82fBE\xae\xee\x1c\x0e\xe3a\xe9\xf0\x90{\n\xca\x9d4\xf8\xc5\xcd\xfdU\xc2\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03']
Bad pipe message: %s [b"\xe0W\x97\xce\x95z\xf2\x93e\x0426NCH\xf4c\x90\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009

4096                  0           0       0.594       0.592       0.723       0.725        0.66       0.654    1.28e+04
8192                  0           0       0.596       0.597       0.717       0.712       0.688       0.716    2.54e+04
10931                 0           0       0.597       0.601       0.714       0.705       0.693       0.706    3.39e+04
21862               0.5           1        0.59       0.582       0.724       0.735       0.693       0.693    6.78e+04
