# Step 1: fine-tune LLM using top result from (fixed) ranker

In [1]:
def step_one(*, k, max_iteration):
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                with torch.no_grad():
                    prompts, target = [], []
    
                    for ex, label in zip(examples, labels):
                        embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                  [ v['title'] 
                                                   for v in ex['profile']
                                                   if v['title'] != ex['title'] 
                                                 ])
                        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                        index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                        titles = [ f'{ex["profile"][ind]["title"]}' for ind in index.tolist() ]
                        concat_titles = ' and '.join([f'"{v}"' for v in titles])
                        prompt = train.append_to_title(ex, concat_titles)
                        prompts.append(prompt)
                        target.append(int(label == train.choices[1]))

                    target = torch.Tensor(target).long().to(device)
                    acc = (taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()

                loss = taskllm.learn(prompts, target, augment=train.swap_refs) if istrain else None
                printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}')

step_one(k=5, max_iteration=3)

n              iter (since)      5 loss (since)       5 acc (since) 5 acc (dev) (since)      dt
1                 0 (    0)       0.709 (0.709)           0 (    0)           0 (    0)  2.73 s
2                 0 (    0)       0.691 (0.673)         0.5 (    1)           0 (    0)  3.45 s
4                 0 (    0)        0.68 (0.669)         0.5 (  0.5)           0 (    0)  4.62 s
8                 0 (    0)       0.689 (0.698)       0.562 (0.625)           0 (    0)   6.8 s
16                0 (    0)       0.673 (0.656)       0.531 (  0.5)           0 (    0)  11.5 s
32                0 (    0)        0.68 (0.688)       0.625 (0.719)           0 (    0)  20.7 s
64                0 (    0)       0.671 (0.662)       0.617 (0.609)           0 (    0)  38.2 s
128               0 (    0)       0.645 (0.619)       0.652 (0.688)           0 (    0)  1.23 m
256               0 (    0)       0.615 (0.585)        0.68 (0.707)           0 (    0)   2.4 m
512               0 (    0)       0.588 

# Step 2: learn ranker

In [2]:
def learn_ranker(*, step1_iter, max_iteration, k):
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(prompts):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in prompts ]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                 [ v['title'] 
                                  for v in ex['profile']
                                  if v['title'] != ex['title'] 
                                ])
        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
        temperature = scores[0].item() - scores[4].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)', 'samps') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedyrewards, allloss, nsamps = [], [], 0

                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        
                        rhatprompts = []
                        prompts = []
                        for rando in randos:
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in rando ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            prompt = dev.append_to_title(ex, concat_titles)
                            prompts.append(prompt)
                            rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n" + '\n'.join(
                                       [ f"Extra: {t}" for t in titles ])
                            rhatprompts.append(rhatprompt)

                        rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=1)
                        exploit = [ exploit.item() ]
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps += len(actionind) 
                        guesses = taskllm.predict([ prompts[a] for a in actionind ], augment=dev.swap_refs).argmax(dim=1)
                        target = int(label == dev.choices[1])
                        rewards = (guesses == target).float().tolist()
                        greedyreward = rewards[0]
                        greedyrewards.append(greedyreward)

                    if istrain:
                        inner_batch_size = 4
                        loss = sum(
                            len(inner_batch[0]) * 
                            rewardpredictor.learn([ rhatprompts[a] for a in inner_batch[0] ], 
                                                    torch.Tensor([ [ r ] for r in inner_batch[1] ]).to(device),
                                                    augment=reward_augment)
                            for inner_batch in zip(chunked(actionind, inner_batch_size), chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None, nsamps if istrain else None)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}_hyper')

learn_ranker(k=5, max_iteration=12, step1_iter=1)

n              iter (since)      5 loss (since)       5 acc (since) 5 acc (dev) (since)       samps (since)      dt
1                 0 (    0)       0.629 (0.629)           1 (    1)           0 (    0)          38 (   38)  2.57 s
2                 0 (    0)       0.584 (0.538)           1 (    1)           0 (    0)        32.5 (   27)  4.55 s
4                 0 (    0)       0.528 (0.472)           1 (    1)           0 (    0)          22 ( 11.5)  7.31 s
8                 0 (    0)       0.576 (0.624)       0.875 ( 0.75)           0 (    0)        23.1 ( 24.2)  14.7 s
16                0 (    0)       0.592 (0.607)        0.75 (0.625)           0 (    0)        20.8 ( 18.4)  28.1 s
32                0 (    0)        0.59 (0.589)       0.734 (0.719)           0 (    0)        17.6 ( 14.5)  51.4 s
64                0 (    0)       0.603 (0.616)       0.695 (0.656)           0 (    0)        17.7 ( 17.8)   1.7 m
128               0 (    0)       0.608 (0.614)       0.711 (0.727)     

# Step 3: Prepare submission files

In [1]:
def prepare_submission_probensemble(*, nvoters, step2_iter, step1_iter, k):
    import json
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader, test_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    dev = dev_loader(batch_size=2)
    test = test_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')
    t5.load_adapter(f'User_keq{k}_t5base_step2_iter{step2_iter}', 'rhat')
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat", model_id=f'User_keq{k}_t5base_step2_iter{step2_iter}')
    
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                [ v['title'] 
                                  for v in ex['profile']
                                  if v['title'] != ex['title'] 
                                ])
        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
        temperature = scores[0].item() - scores[4].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    print(f'*** step1_iter: {step1_iter} step2_iter: {step2_iter} nvoters {nvoters} ***')
    
    with ProgressPrinter(f'{k} acc (dev)') as printer:
        devgolds, testgolds = [], []
        
        for isdev, (examples, labels) in interleave(dev, test):
            greedyrewards = []
            for ex, label in zip(examples, labels):
                with torch.no_grad():
                    randos = randomized_similarity(ex, 64)
                    
                    rhatprompts = []
                    prompts = []
                    for rando in randos:
                        titles = [ f'{ex["profile"][ind]["title"]}' for ind in rando ]
                        concat_titles = ' and '.join([f'"{v}"' for v in titles])
                        prompt = dev.append_to_title(ex, concat_titles)
                        prompts.append(prompt)
                        rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n" + '\n'.join(
                                   [ f"Extra: {t}" for t in titles ])
                        rhatprompts.append(rhatprompt)

                    rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                    voters = torch.topk(rhats, k=min(nvoters, len(rhats)), dim=0).indices.view(-1).to('cpu').tolist()
                    guess = taskllm.predict([ prompts[v] for v in voters ], augment=dev.swap_refs).logsumexp(dim=0, keepdim=True).argmax(dim=1)
                        
                    if isdev:
                        target = int(label == dev.choices[1])
                        reward = int(guess.item() == target)
                        greedyrewards.append(reward)

                    (devgolds if isdev else testgolds).append({ 'id': ex['id'], 'output': "[2]" if guess else "[1]" })

            greedyacc = torch.Tensor(greedyrewards).float().mean().item() if isdev else None

            printer.addobs(greedyacc)

        printer.print()
        printer.autoprint = False

        for wut, golds in ( ('dev', devgolds), ('test', testgolds) ):
            with open(f'lamp1u_{wut}golds_t5base_keq{k}_step1_iter{step1_iter}_step2_iter{step2_iter}_pens_nvoters{nvoters}.json', 'w') as jsonfile:
                json.dump({ 'task': 'LaMP_1', 'golds': golds }, jsonfile)
            
from Fork import SubProcess
from Util import BadPipe
for nvoters in [1, 3, 5]:
    with BadPipe(), SubProcess() as process: process.parent or prepare_submission_probensemble(k=5, step1_iter=1, step2_iter='6_hyper', nvoters=nvoters)

*** step1_iter: 1 step2_iter: 6_hyper nvoters 1 ***
n       5 acc (dev) (since)      dt
1               0.5 (  0.5)  1.37 s
4               0.5 (  0.5)  4.27 s
8               0.5 (  0.5)  7.63 s
16            0.625 ( 0.75)  14.2 s
32            0.688 ( 0.75)  27.6 s
64            0.734 (0.781)  56.3 s
128           0.734 (0.734)  1.86 m
256           0.719 (0.703)  3.68 m
512           0.717 (0.715)  7.43 m
1024          0.721 (0.725)  14.9 m
2048          0.726 (0.731)    30 m
2500          0.728 (0.739)  36.6 m
*** step1_iter: 1 step2_iter: 6_hyper nvoters 3 ***
n       5 acc (dev) (since)      dt
1               0.5 (  0.5)  1.39 s
4               0.5 (  0.5)  4.36 s
8               0.5 (  0.5)  7.78 s
16            0.625 ( 0.75)  14.6 s
32            0.688 ( 0.75)  28.3 s
64            0.719 ( 0.75)  57.4 s
128           0.727 (0.734)   1.9 m
256           0.727 (0.727)  3.74 m
512           0.725 (0.723)  7.56 m
1024          0.727 (0.729)  15.2 m
2048          0.734 (0.742)  30.

> Hi,
> 
> This is the result of your most recent submission to the LaMP benchmark.
>
>
> {"accuracy": 0.7436}
>
> Best,
> Alireza