# Step 1: fine-tune LLM using top result from (fixed) ranker

In [1]:
def step_one(k):
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        target = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                inputs.append(input)
                                target.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        acc = (taskllm.predict(inputs, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()
    
                    loss = taskllm.learn(inputs, target, augment=train.swap_refs) if istrain else None
                    printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                taskllm.save_pretrained(f'User_keq{k}_t5base_step1')

step_one(1)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.708       0.708           0           0           0           0        1.21
2                     0           0        0.69       0.673         0.5           1           0           0        1.74
4                     0           0        0.68       0.661       0.667           1           1           1        2.53
8                     0           0       0.693       0.703       0.571         0.5           1           0        4.53
16                    0           0       0.696         0.7       0.577       0.583         0.5        0.25        8.03
32                    0           0       0.694       0.692       0.538         0.5        0.75           1        16.1
64                    0           0        0.69       0.686       0.588        0.64       0.654       0.571          31
128                   0           0     

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [1]:
def learn_ranker(rank):
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                            
                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

learn_ranker(rank=8)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0       0.718       0.718           0           0           0           0        1.85
2                     0           0       0.708       0.699           0           0           0           0        2.94
4                     0           0       0.709        0.71         0.5           1           0           0        5.16
8                     0           0       0.686       0.656       0.571       0.667           1           1         9.1
16                    0           0       0.706       0.724       0.467       0.375           1           0        19.1
32                    0           0       0.638       0.565       0.621       0.786           1           1        36.9
64                    0           0       0.637       0.636       0.667       0.714       0.714         0.5        74.1
128                   0           0     

# Step 3: Prepare leaderboard submission files

In [23]:
def prepare_submission(*, rank):
    import json
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader, test_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    dev = dev_loader(batch_size=2)
    test = test_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')
    t5.load_adapter(f'User_keq1_t5base_step2_rankeq{rank}', 'rhat')
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat", model_id=f'User_keq1_t5base_step2_rankeq{rank}')
    
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter(f'{rank} acc (dev)') as printer:
        devgolds, testgolds = [], []
        
        for isdev, (examples, labels) in interleave(dev, test):
            greedyrewards = []
            for ex, label in zip(examples, labels):
                with torch.no_grad():
                    embeddings = dev.embed( [ ex['ref1'], ex['ref2'] ] + 
                                              [ v['title'] 
                                               for v in ex['profile']
                                               if v['title'] != ex['title'] 
                                             ])
                    scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                    index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                    prompts = []
                    rhatprompts = []
                    for n, oneind in enumerate(index.tolist()):
                        titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                        concat_titles = ' and '.join([f'"{v}"' for v in titles])
                        input = dev.append_to_title(ex, concat_titles)
                        prompts.append(input)
                        rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                        rhatprompts.append(rhatprompt)

                    rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                    greedy = torch.argmax(rhats, dim=0).item()
                    guess = taskllm.predict([ prompts[greedy] ], augment=dev.swap_refs).argmax(dim=1)
                    if isdev:
                        target = int(label == dev.choices[1])
                        reward = int(guess.item() == target)
                        greedyrewards.append(reward)

                    (devgolds if isdev else testgolds).append({ 'id': ex['id'], 'output': "[2]" if guess else "[1]" })

            greedyacc = torch.Tensor(greedyrewards).float().mean().item() if isdev else None

            printer.addobs(greedyacc)

        printer.print()
        printer.autoprint = False

        for wut, golds in ( ('dev', devgolds), ('test', testgolds) ):
            with open(f'lamp1u_{wut}golds_rankeq{rank}.json', 'w') as jsonfile:
                json.dump({ 'task': 'LaMP_1', 'golds': golds }, jsonfile)
            
prepare_submission(rank=8)

n           8 acc (dev)       since      dt (s)
1                   0.5         0.5        1.08
4                  0.75           1         2.8
8                 0.625         0.5        5.17
16                0.688        0.75        9.38
32                0.688       0.688        18.6
64                0.703       0.719        38.5
128               0.688       0.672        77.8
256               0.703       0.719         154
512               0.686       0.668         318
1024              0.704       0.723         642
2048              0.712       0.721    1.28e+03
2500              0.718       0.741    1.57e+03
