# Step 1: fine-tune LLM using top result from (fixed) ranker

In [5]:
def step_one(*, k, max_iteration):
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                with torch.no_grad():
                    prompts = []
                    target = []
    
                    for ex, label in zip(examples, labels):
                        embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                  [ v['title'] 
                                                   for v in ex['profile']
                                                   if v['title'] != ex['title'] 
                                                 ])
                        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                        index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                        titles = [ f'{ex["profile"][ind]["title"]}' for ind in index.tolist() ]
                        concat_titles = ' and '.join([f'"{v}"' for v in titles])
                        prompt = train.append_to_title(ex, concat_titles)
                        prompts.append(prompt)
                        target.append(int(label == train.choices[1]))

                    target = torch.Tensor(target).long().to(device)
                    acc = (taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()

                loss = taskllm.learn(prompts, target, augment=train.swap_refs) if istrain else None
                printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            taskllm.save_pretrained(f'User_keq{k}_t5xxl_step1_iter{iteration}')

from Util import Filter
import sys
sys.__stderr__ = sys.__stderr__ if type(sys.__stderr__) == Filter else Filter(sys.__stderr__, r'Bad pipe message') 
from Fork import SubProcess
with SubProcess() as process: process.parent or step_one(k=5, max_iteration=3)

n              iter (since)      5 loss (since)       5 acc (since) 5 acc (dev) (since)      dt
1             0.000 (0.000)       0.670 (0.670)       0.500 (0.500)       0.000 (0.000)  4.05 s
2             0.000 (0.000)       0.748 (0.825)       0.500 (0.500)       0.000 (0.000)  7.01 s
4             0.000 (0.000)       0.650 (0.552)       0.750 (1.000)       0.000 (0.000)  13.2 s
8             0.000 (0.000)       0.686 (0.722)       0.562 (0.375)       0.000 (0.000)    25 s
16            0.000 (0.000)       0.693 (0.699)       0.531 (0.500)       0.000 (0.000)  49.2 s
32            0.000 (0.000)       0.682 (0.672)       0.531 (0.531)       0.000 (0.000)  1.65 m
64            0.000 (0.000)       0.677 (0.672)       0.570 (0.609)       0.000 (0.000)   3.3 m
128           0.000 (0.000)       0.659 (0.641)       0.629 (0.688)       0.000 (0.000)  6.56 m
256           0.000 (0.000)       0.624 (0.588)       0.689 (0.750)       0.000 (0.000)  13.1 m
512           0.000 (0.000)       0.616 

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

$\gamma$ = 1, hypercube

In [None]:
def learn_ranker(*, step1_iter, max_iteration, k): # NB: nsamps is off by a factor of 2
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import PeftConfig, IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    t5.load_adapter(f'User_keq{k}_t5xxl_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(prompts):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in prompts ]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                 [ v['title'] 
                                  for v in ex['profile']
                                  if v['title'] != ex['title'] 
                                ])
        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
        temperature = scores[0].item() - scores[4].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)', 'samps') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedyrewards, allloss, nsamps = [], [], 0
                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        
                        rhatprompts = []
                        prompts = []
                        for rando in randos:
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in rando ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            prompt = dev.append_to_title(ex, concat_titles)
                            prompts.append(prompt)
                            rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n" + '\n'.join(
                                       [ f"Extra: {t}" for t in titles ])
                            rhatprompts.append(rhatprompt)

                        rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=1)
                        exploit = [ exploit.item() ]
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps += len(actionind) 
                        guesses = taskllm.predict([ prompts[a] for a in actionind ], augment=dev.swap_refs).argmax(dim=1)
                        target = int(label == dev.choices[1])
                        rewards = (guesses == target).float().tolist()
                        greedyreward = rewards[0]
                        greedyrewards.append(greedyreward)

                    if istrain:
                        inner_batch_size = 4
                        loss = sum(
                            len(inner_batch[0]) * 
                            rewardpredictor.learn([ rhatprompts[a] for a in inner_batch[0] ], 
                                                    torch.Tensor([ [ r ] for r in inner_batch[1] ]).to(device),
                                                    augment=reward_augment)
                            for inner_batch in zip(chunked(actionind, inner_batch_size), chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None, nsamps if istrain else None)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq{k}_t5xxl_step2_iter{iteration}')

from Fork import SubProcess
from Util import BadPipe
with BadPipe(), SubProcess() as process: process.parent or learn_ranker(k=5, max_iteration=12, step1_iter=0)

n              iter (since)      5 loss (since)       5 acc (since) 5 acc (dev) (since)       samps (since)      dt
1             0.000 (0.000)       0.533 (0.533)       1.000 (1.000)       0.000 (0.000)      30.000 (30.000)  24.5 s
2             0.000 (0.000)       0.552 (0.571)       1.000 (1.000)       0.000 (0.000)      20.000 (10.000)  41.7 s
4             0.000 (0.000)       0.504 (0.456)       1.000 (1.000)       0.000 (0.000)      16.250 (12.500)  1.31 m
8             0.000 (0.000)       0.529 (0.554)       0.812 (0.625)       0.000 (0.000)      14.375 (12.500)  2.58 m
16            0.000 (0.000)       0.577 (0.624)       0.719 (0.625)       0.000 (0.000)      17.375 (20.375)  5.87 m
32            0.000 (0.000)       0.632 (0.688)       0.688 (0.656)       0.000 (0.000)      16.156 (14.938)  11.3 m
64            0.000 (0.000)       0.622 (0.611)       0.711 (0.734)       0.000 (0.000)      16.609 (17.062)  22.6 m
128           0.000 (0.000)       0.584 (0.545)       0.758 (0.80