# Case 1: Fixed Ranker, no-fine tuning

In [1]:
def fixed_ranker_noft(*, k):
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    dev = dev_loader(batch_size=1)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    taskllm = TaskLLM(t5=t5)

    with ProgressPrinter(f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for examples, labels in dev:
            with torch.no_grad():
                prompts = []
                target = []

                for ex, label in zip(examples, labels):
                    embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                            [ v['title'] 
                                              for v in ex['profile']
                                              if v['title'] != ex['title'] 
                                            ])
                    scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                    index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                    titles = [ f'{ex["profile"][ind]["title"]}' for ind in index.tolist() ]
                    concat_titles = ' and '.join([f'"{v}"' for v in titles])
                    prompt = dev.append_to_title(ex, concat_titles)
                    prompts.append(prompt)
                    target.append(int(label == dev.choices[1]))

                target = torch.Tensor(target).long().to(device)
                acc = (taskllm.predict(prompts, augment=dev.swap_refs).argmax(dim=1) == target).float().mean().item()

            printer.addobs(acc)

fixed_ranker_noft(k=5)

n       5 acc (dev) (since)      dt
1             1.000 (1.000)  3.41 s
2             0.500 (0.000)  5.03 s
4             0.250 (0.000)  8.82 s
8             0.500 (0.750)  16.2 s
16            0.375 (0.250)  30.3 s
32            0.469 (0.562)  58.1 s
64            0.578 (0.688)  1.93 m
128           0.523 (0.469)  3.83 m
256           0.520 (0.516)  7.61 m
512           0.520 (0.520)  15.2 m
1024          0.519 (0.518)  30.5 m
2048          0.527 (0.536)  1.06 h
2500          0.534 (0.564)  1.28 h


# Case 2: Learned Ranker, no fine-tuning

In [None]:
def learn_ranker(*, max_iteration, k):
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import PeftConfig, IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5)
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(prompts):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in prompts ]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                 [ v['title'] 
                                  for v in ex['profile']
                                  if v['title'] != ex['title'] 
                                ])
        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
        temperature = scores[0].item() - scores[4].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)', 'samps') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedyrewards, allloss, nsamps = [], [], []
                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        
                        rhatprompts = []
                        prompts = []
                        for rando in randos:
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in rando ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            prompt = dev.append_to_title(ex, concat_titles)
                            prompts.append(prompt)
                            rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n" + '\n'.join(
                                       [ f"Extra: {t}" for t in titles ])
                            rhatprompts.append(rhatprompt)

                        rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=1)
                        exploit = [ exploit.item() ]
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps.append(len(actionind)) 
                        guesses = taskllm.predict([ prompts[a] for a in actionind ], augment=dev.swap_refs).argmax(dim=1)
                        target = int(label == dev.choices[1])
                        rewards = (guesses == target).float().tolist()
                        greedyreward = rewards[0]
                        greedyrewards.append(greedyreward)

                    if istrain:
                        inner_batch_size = 4
                        loss = sum(
                            len(inner_batch[0]) * 
                            rewardpredictor.learn([ rhatprompts[a] for a in inner_batch[0] ], 
                                                    torch.Tensor([ [ r ] for r in inner_batch[1] ]).to(device),
                                                    augment=reward_augment)
                            for inner_batch in zip(chunked(actionind, inner_batch_size), chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None
                nsamples = torch.Tensor(nsamps).mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None, nsamples)

            printer.print()
            printer.autoprint = False

learn_ranker(k=5, max_iteration=12)

n              iter (since)      5 loss (since)       5 acc (since) 5 acc (dev) (since)       samps (since)      dt
1             0.000 (0.000)       0.693 (0.693)       1.000 (1.000)       0.000 (0.000)      22.000 (22.000)  1.07 m
2             0.000 (0.000)       0.694 (0.696)       1.000 (1.000)       0.000 (0.000)      36.000 (50.000)  3.34 m
4             0.000 (0.000)       0.712 (0.729)       0.750 (0.500)       0.000 (0.000)      24.375 (12.750)  5.06 m
8             0.000 (0.000)       0.681 (0.650)       0.625 (0.500)       0.000 (0.000)      21.500 (18.625)  9.58 m
16            0.000 (0.000)       0.684 (0.686)       0.656 (0.688)       0.000 (0.000)      18.719 (15.938)  18.1 m
32            0.000 (0.000)       0.689 (0.693)       0.625 (0.594)       0.000 (0.000)      25.047 (31.375)  43.7 m
64            0.000 (0.000)       0.689 (0.690)       0.547 (0.469)       0.000 (0.000)      26.492 (27.938)   1.5 h
128           0.000 (0.000)       0.688 (0.687)       0.543 (0.53