# Step 1: fine-tune LLM using top result from (fixed) ranker

In [2]:
def step_one(*, k):
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        prompts = []
                        target = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in index.tolist() ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            prompt = train.append_to_title(ex, concat_titles)
                            prompts.append(prompt)
                            target.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        acc = (taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()
    
                    loss = taskllm.learn(prompts, target, augment=train.swap_refs) if istrain else None
                    printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                taskllm.save_pretrained(f'User_keq{k}_t5xxl_step1')

from Fork import SubProcess
with SubProcess() as process: process.parent or step_one(k=4)

n                  iter       since      4 loss       since       4 acc       since 4 acc (dev)       since      dt (s)
1                     0           0        0.72        0.72           0           0           0           0        4.03
2                     0           0       0.739       0.758        0.25         0.5           0           0        6.96
4                     0           0       0.737       0.732       0.333         0.5           1           1        11.1
8                     0           0       0.714       0.698       0.429         0.5           1           0          23
16                    0           0       0.732       0.753       0.385       0.333       0.667         0.5        42.9
32                    0           0       0.728       0.724       0.404       0.423       0.583         0.5        85.3
64                    0           0       0.719       0.709       0.431        0.46       0.462       0.357         167
128                   0           0     

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

$\gamma = 1$, hypercube

In [None]:
def learn_ranker(*, step1_iter, max_iteration, k):
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import PeftConfig, IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    t5.load_adapter(f'User_keq{k}_t5xxl_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(prompts):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in prompts ]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                 [ v['title'] 
                                  for v in ex['profile']
                                  if v['title'] != ex['title'] 
                                ])
        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
        temperature = scores[0].item() - scores[4].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)', 'samps') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedyrewards, allloss, nsamps = [], [], 0
                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        
                        rhatprompts = []
                        prompts = []
                        for rando in randos:
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in rando ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            prompt = dev.append_to_title(ex, concat_titles)
                            prompts.append(prompt)
                            rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n" + '\n'.join(
                                       [ f"Extra: {t}" for t in titles ])
                            rhatprompts.append(rhatprompt)

                        rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=1)
                        exploit = [ exploit.item() ]
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps += len(actionind) 
                        guesses = taskllm.predict([ prompts[a] for a in actionind ], augment=dev.swap_refs).argmax(dim=1)
                        target = int(label == dev.choices[1])
                        rewards = (guesses == target).float().tolist()
                        greedyreward = rewards[0]
                        greedyrewards.append(greedyreward)

                    if istrain:
                        inner_batch_size = 4
                        loss = sum(
                            len(inner_batch[0]) * 
                            rewardpredictor.learn([ rhatprompts[a] for a in inner_batch[0] ], 
                                                    torch.Tensor([ [ r ] for r in inner_batch[1] ]).to(device),
                                                    augment=reward_augment)
                            for inner_batch in zip(chunked(actionind, inner_batch_size), chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None, nsamps if istrain else None)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq{k}_t5xxl_step2_iter{iteration}')

from Fork import SubProcess
from Util import BadPipe
with BadPipe(), SubProcess() as process: process.parent or learn_ranker(k=4, max_iteration=12, step1_iter=0)

n              iter (since)      4 loss (since)       4 acc (since) 4 acc (dev) (since)       samps (since)      dt
1                 0 (    0)       0.626 (0.626)         0.5 (  0.5)           0 (    0)          23 (   23)  19.4 s
2                 0 (    0)       0.658 ( 0.69)         0.5 (  0.5)           0 (    0)          18 (   13)  36.4 s
4                 0 (    0)       0.539 (0.421)        0.75 (    1)           0 (    0)        22.5 (   27)  1.41 m
8                 0 (    0)       0.598 (0.657)       0.688 (0.625)           0 (    0)          18 ( 13.5)   2.6 m
16                0 (    0)       0.601 (0.605)       0.719 ( 0.75)           0 (    0)        17.8 ( 17.5)  5.34 m
32                0 (    0)       0.607 (0.613)       0.703 (0.688)           0 (    0)        15.5 ( 13.2)  9.81 m
64                0 (    0)       0.608 (0.609)       0.688 (0.672)           0 (    0)        15.4 ( 15.4)  19.7 m
128               0 (    0)       0.586 (0.564)       0.719 ( 0.75)     