# Step 1: fine-tune LLM using top result from (fixed) ranker

In [None]:
def step_one(*, k: int, max_iteration: int) -> None:
    """Fine-tune the "taskllm" IA3 adapter on prompts built from retrieved profile titles.

    For each example, the ``k`` profile titles whose embeddings have the
    largest inner product with either ``ref1`` or ``ref2`` are retrieved, and
    one prompt is built per retrieved title.  After every iteration the
    adapter is saved as ``User_keq{k}_t5base_step1_iter{iteration}``.

    NOTE(review): this cell looks unfinished -- the unconditional
    ``assert False`` below stops the run on the first batch (unless Python
    runs with ``-O``), so nothing after it currently executes.

    Args:
        k: number of profile titles retrieved per example (passed to ``torch.topk``).
        max_iteration: number of passes over the interleaved train/dev batches.
    """
    from TaskLLM import TaskLLM
    # NOTE(review): every other cell imports its loaders from PersonalizedCitation,
    # and this body reads ex['ref1'] / ex['ref2'] -- confirm this module is intended.
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    # Base model plus a fresh IA3 adapter named "taskllm".
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    with ProgressPrinter('iter', f'{k} loss', f'{k} MAE', f'{k} MAE (dev)') as printer:
        for iteration in range(max_iteration):
            # presumably istrain is truthy for batches drawn from `train` -- verify against Util.interleave
            for istrain, (examples, labels) in interleave(train, dev):
                with torch.no_grad():
                    inputs = []
                    target = []
    
                    for ex, label in zip(examples, labels):
                        # Rows 0 and 1 embed the two references; the remaining rows embed
                        # the profile titles that differ from the example's own title.
                        embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                  [ v['title'] 
                                                   for v in ex['profile']
                                                   if v['title'] != ex['title'] 
                                                 ])
                        # Score of a title = max over the two references of the inner product.
                        scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                        index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                        for n, oneind in enumerate(index.tolist()):
                            # NOTE(review): `oneind` indexes the *filtered* title list above, but is
                            # used here on the unfiltered ex["profile"]; the two disagree whenever
                            # the filter removed an entry.
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            input = train.append_to_title(ex, concat_titles)
                            inputs.append(input)
                            # Binary target: 1 when the label equals train.choices[1], else 0.
                            target.append(int(label == train.choices[1]))

                    target = torch.Tensor(target).long().to(device)
                    guess = taskllm.predict(inputs)
                    # NOTE(review): debugging probe -- always raises AssertionError carrying
                    # guess.shape, so the lines below are unreachable.
                    assert False, guess.shape
                    mae = torch.abs(guess - target).mean().item()
    
                    # NOTE(review): learn() is called inside the torch.no_grad() block --
                    # confirm TaskLLM.learn re-enables gradients, otherwise move it out.
                    loss = taskllm.learn(inputs, target) if istrain else None
                    printer.addobs(iteration, loss, mae if istrain else None, mae if not istrain else None)

            printer.print()
            printer.autoprint = False
            # NOTE(review): the step-2 cells load 'User_keq1_t5base_step1' (no `_iter` suffix);
            # presumably one of these checkpoints is renamed by hand -- confirm.
            taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}')

from Fork import SubProcess
from Util import BadPipe

# Only the side for which `process.parent` is falsy runs the training.
with BadPipe(), SubProcess() as process:
    if not process.parent:
        step_one(k=1, max_iteration=3)

 96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎    | 2.92G/3.04G [04:20<00:09, 11.9MiB/s]

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [None]:
def learn_ranker(rank):
    """Train the "rhat" reward predictor that re-ranks retrieved profile titles.

    Loads flan-t5-base with the step-1 "taskllm" adapter and adds a fresh IA3
    adapter named "rhat".  For every example, up to ``rank`` profile titles
    whose embeddings have the largest inner product with either reference are
    retrieved; each one yields a task-LLM prompt and a reward-predictor
    prompt.  A candidate's reward is 1.0 when the task LLM's argmax guess
    equals the target and 0.0 otherwise; the reward predictor is fit to those
    rewards.  The reported accuracy is the reward of the candidate that the
    reward predictor scores highest ("greedy").

    Iteration 0 processes every batch and learns on the training ones;
    iteration 1 only evaluates the non-training batches.  The reward predictor
    is saved after iteration 0 as ``User_keq1_t5base_step2_rankeq{rank}``.

    Args:
        rank: maximum number of retrieved profile titles scored per example.
    """
    import re

    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    def reward_augment(inputs):
        # Swap the Ref1 and Ref2 fields of every reward-predictor prompt.
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    # Accumulate over the whole batch; resetting these per example
                    # made the reported numbers reflect only the last example.
                    greedyrewards = []
                    allloss = []
                    for ex, label in zip(examples, labels):
                        with torch.no_grad():
                            # Keep the filtered list: topk indices refer to it,
                            # not to the unfiltered ex['profile'].
                            candidates = [ v['title']
                                           for v in ex['profile']
                                           if v['title'] != ex['title'] ]
                            embeddings = train.embed([ ex['ref1'], ex['ref2'] ] + candidates)
                            # Score of a title = max over the two references of the inner product.
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            # torch.topk raises when k exceeds the number of candidates.
                            topk = min(rank, len(candidates))
                            index = torch.topk(scores, dim=0, k=topk).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for oneind in index.tolist():
                                title = f'{candidates[oneind]}'
                                prompts.append(train.append_to_title(ex, f'"{title}"'))
                                rhatprompts.append(f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {title}")

                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyrewards.append(rewards[greedy, 0].item())

                        # Learning happens outside no_grad.
                        if istrain:
                            allloss.append(rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment))

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess

# One SubProcess per rank; only the side with a falsy `process.parent` trains.
for rank in range(8, 9):
    with SubProcess() as process:
        if not process.parent:
            learn_ranker(rank)

## No benefit from multiple passes

In [None]:
# two training passes
def learn_ranker(rank, max_iteration=3):
    """Train the "rhat" reward predictor with several training passes.

    Loads flan-t5-base with the step-1 "taskllm" adapter and adds a fresh IA3
    adapter named "rhat".  For every example, up to ``rank`` profile titles
    whose embeddings have the largest inner product with either reference are
    retrieved; each one yields a task-LLM prompt and a reward-predictor
    prompt.  A candidate's reward is 1.0 when the task LLM's argmax guess
    equals the target and 0.0 otherwise; the reward predictor is fit to those
    rewards.  The reported accuracy is the reward of the candidate that the
    reward predictor scores highest ("greedy").

    The first ``max_iteration - 1`` iterations process every batch and learn
    on the training ones; the last iteration only evaluates the non-training
    batches.  The reward predictor is saved once, after the final iteration,
    as ``User_keq1_t5base_step2_rankeq{rank}``.

    Args:
        rank: maximum number of retrieved profile titles scored per example.
        max_iteration: total number of iterations; the default of 3 gives two
            training passes followed by one evaluation-only pass.
    """
    import re

    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    def reward_augment(inputs):
        # Swap the Ref1 and Ref2 fields of every reward-predictor prompt.
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration + 1 < max_iteration or not istrain:
                    # Accumulate over the whole batch; resetting these per example
                    # made the reported numbers reflect only the last example.
                    greedyrewards = []
                    allloss = []
                    for ex, label in zip(examples, labels):
                        with torch.no_grad():
                            # Keep the filtered list: topk indices refer to it,
                            # not to the unfiltered ex['profile'].
                            candidates = [ v['title']
                                           for v in ex['profile']
                                           if v['title'] != ex['title'] ]
                            embeddings = train.embed([ ex['ref1'], ex['ref2'] ] + candidates)
                            # Score of a title = max over the two references of the inner product.
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            # torch.topk raises when k exceeds the number of candidates.
                            topk = min(rank, len(candidates))
                            index = torch.topk(scores, dim=0, k=topk).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for oneind in index.tolist():
                                title = f'{candidates[oneind]}'
                                prompts.append(train.append_to_title(ex, f'"{title}"'))
                                rhatprompts.append(f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {title}")

                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyrewards.append(rewards[greedy, 0].item())

                        # Learning happens outside no_grad.
                        if istrain:
                            allloss.append(rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment))

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

        rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess

# One SubProcess per rank; only the side with a falsy `process.parent` trains.
for rank in range(8, 9):
    with SubProcess() as process:
        if not process.parent:
            learn_ranker(rank)

In [None]:
# three training passes
def learn_ranker(rank, max_iteration=4):
    """Train the "rhat" reward predictor with several training passes.

    Loads flan-t5-base with the step-1 "taskllm" adapter and adds a fresh IA3
    adapter named "rhat".  For every example, up to ``rank`` profile titles
    whose embeddings have the largest inner product with either reference are
    retrieved; each one yields a task-LLM prompt and a reward-predictor
    prompt.  A candidate's reward is 1.0 when the task LLM's argmax guess
    equals the target and 0.0 otherwise; the reward predictor is fit to those
    rewards.  The reported accuracy is the reward of the candidate that the
    reward predictor scores highest ("greedy").

    The first ``max_iteration - 1`` iterations process every batch and learn
    on the training ones; the last iteration only evaluates the non-training
    batches.  The reward predictor is saved once, after the final iteration,
    as ``User_keq1_t5base_step2_rankeq{rank}``.

    Args:
        rank: maximum number of retrieved profile titles scored per example.
        max_iteration: total number of iterations; the default of 4 gives
            three training passes followed by one evaluation-only pass.
    """
    import re

    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    def reward_augment(inputs):
        # Swap the Ref1 and Ref2 fields of every reward-predictor prompt.
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration + 1 < max_iteration or not istrain:
                    # Accumulate over the whole batch; resetting these per example
                    # made the reported numbers reflect only the last example.
                    greedyrewards = []
                    allloss = []
                    for ex, label in zip(examples, labels):
                        with torch.no_grad():
                            # Keep the filtered list: topk indices refer to it,
                            # not to the unfiltered ex['profile'].
                            candidates = [ v['title']
                                           for v in ex['profile']
                                           if v['title'] != ex['title'] ]
                            embeddings = train.embed([ ex['ref1'], ex['ref2'] ] + candidates)
                            # Score of a title = max over the two references of the inner product.
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            # torch.topk raises when k exceeds the number of candidates.
                            topk = min(rank, len(candidates))
                            index = torch.topk(scores, dim=0, k=topk).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for oneind in index.tolist():
                                title = f'{candidates[oneind]}'
                                prompts.append(train.append_to_title(ex, f'"{title}"'))
                                rhatprompts.append(f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {title}")

                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyrewards.append(rewards[greedy, 0].item())

                        # Learning happens outside no_grad.
                        if istrain:
                            allloss.append(rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment))

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

        rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess

# One SubProcess per rank; only the side with a falsy `process.parent` trains.
for rank in range(8, 9):
    with SubProcess() as process:
        if not process.parent:
            learn_ranker(rank)

## No benefit from larger rank

In [None]:
def learn_ranker(rank, max_iteration):
    """Train the "rhat" reward predictor for a given candidate count and pass count.

    Loads flan-t5-base with the step-1 "taskllm" adapter and adds a fresh IA3
    adapter named "rhat".  For every example, up to ``rank`` profile titles
    whose embeddings have the largest inner product with either reference are
    retrieved; each one yields a task-LLM prompt and a reward-predictor
    prompt.  A candidate's reward is 1.0 when the task LLM's argmax guess
    equals the target and 0.0 otherwise; the reward predictor is fit to those
    rewards.  The reported accuracy is the reward of the candidate that the
    reward predictor scores highest ("greedy").

    The first ``max_iteration - 1`` iterations process every batch and learn
    on the training ones; the last iteration only evaluates the non-training
    batches.  The reward predictor is saved once, after the final iteration,
    as ``User_keq1_t5base_step2_rankeq{rank}``.

    Args:
        rank: maximum number of retrieved profile titles scored per example.
        max_iteration: total number of iterations, including the final
            evaluation-only one.
    """
    import re

    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    def reward_augment(inputs):
        # Swap the Ref1 and Ref2 fields of every reward-predictor prompt.
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration + 1 < max_iteration or not istrain:
                    # Accumulate over the whole batch; resetting these per example
                    # made the reported numbers reflect only the last example.
                    greedyrewards = []
                    allloss = []
                    for ex, label in zip(examples, labels):
                        with torch.no_grad():
                            # Keep the filtered list: topk indices refer to it,
                            # not to the unfiltered ex['profile'].
                            candidates = [ v['title']
                                           for v in ex['profile']
                                           if v['title'] != ex['title'] ]
                            embeddings = train.embed([ ex['ref1'], ex['ref2'] ] + candidates)
                            # Score of a title = max over the two references of the inner product.
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            # torch.topk raises when k exceeds the number of candidates,
                            # which larger ranks make more likely.
                            topk = min(rank, len(candidates))
                            index = torch.topk(scores, dim=0, k=topk).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for oneind in index.tolist():
                                title = f'{candidates[oneind]}'
                                prompts.append(train.append_to_title(ex, f'"{title}"'))
                                rhatprompts.append(f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {title}")

                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyrewards.append(rewards[greedy, 0].item())

                        # Learning happens outside no_grad.
                        if istrain:
                            allloss.append(rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment))

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

        rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess

# Sweep ranks 12, 16, 20, 24 with three iterations each, one SubProcess per rank.
for rank in range(12, 28, 4):
    with SubProcess() as process:
        if not process.parent:
            learn_ranker(rank, 3)

# Step 3: Prepare leaderboard submission files

In [None]:
def prepare_submission(*, rank):
    """Write dev and test prediction files using the trained ranker and task LLM.

    Loads flan-t5-base with the step-1 "taskllm" adapter and the step-2 "rhat"
    adapter saved for ``rank``.  For every example, up to ``rank`` profile
    titles whose embeddings have the largest inner product with either
    reference are retrieved, the reward predictor picks the one it scores
    highest, and the task LLM answers the prompt built from that title.  The
    answer is recorded as "[2]" when the argmax guess is non-zero and "[1]"
    otherwise.  Accuracy is reported only for batches flagged ``isdev``.

    Predictions are written to ``lamp1u_devgolds_rankeq{rank}.json`` and
    ``lamp1u_testgolds_rankeq{rank}.json``.

    Args:
        rank: maximum number of retrieved profile titles scored per example;
            also selects which saved reward predictor is loaded.
    """
    import json
    import re

    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader, test_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    dev = dev_loader(batch_size=2)
    test = test_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter('User_keq1_t5base_step1', 'taskllm')
    t5.load_adapter(f'User_keq1_t5base_step2_rankeq{rank}', 'rhat')
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat", model_id=f'User_keq1_t5base_step2_rankeq{rank}')

    def reward_augment(inputs):
        # Swap the Ref1 and Ref2 fields of every reward-predictor prompt.
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter(f'{rank} acc (dev)') as printer:
        devgolds, testgolds = [], []

        for isdev, (examples, labels) in interleave(dev, test):
            greedyrewards = []
            for ex, label in zip(examples, labels):
                with torch.no_grad():
                    # Keep the filtered list: topk indices refer to it,
                    # not to the unfiltered ex['profile'].
                    candidates = [ v['title']
                                   for v in ex['profile']
                                   if v['title'] != ex['title'] ]
                    embeddings = dev.embed([ ex['ref1'], ex['ref2'] ] + candidates)
                    # Score of a title = max over the two references of the inner product.
                    scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                    # torch.topk raises when k exceeds the number of candidates.
                    topk = min(rank, len(candidates))
                    index = torch.topk(scores, dim=0, k=topk).indices.to('cpu')
                    prompts = []
                    rhatprompts = []
                    for oneind in index.tolist():
                        title = f'{candidates[oneind]}'
                        prompts.append(dev.append_to_title(ex, f'"{title}"'))
                        rhatprompts.append(f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {title}")

                    rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                    greedy = torch.argmax(rhats, dim=0).item()
                    # Only the candidate preferred by the reward predictor is sent to the task LLM.
                    guess = taskllm.predict([ prompts[greedy] ], augment=dev.swap_refs).argmax(dim=1)
                    choice = guess.item()
                    if isdev:
                        target = int(label == dev.choices[1])
                        greedyrewards.append(int(choice == target))

                    (devgolds if isdev else testgolds).append({ 'id': ex['id'], 'output': "[2]" if choice else "[1]" })

            greedyacc = torch.Tensor(greedyrewards).float().mean().item() if isdev else None

            printer.addobs(greedyacc)

        printer.print()
        printer.autoprint = False

        for wut, golds in ( ('dev', devgolds), ('test', testgolds) ):
            with open(f'lamp1u_{wut}golds_rankeq{rank}.json', 'w') as jsonfile:
                json.dump({ 'task': 'LaMP_1', 'golds': golds }, jsonfile)
            
from Util import Filter
import sys

# Wrap stderr only once so re-running the cell does not stack filters.
# isinstance (rather than an exact type comparison) also recognises Filter subclasses.
if not isinstance(sys.stderr, Filter):
    sys.stderr = Filter(sys.stderr, r'Bad pipe message')
from Fork import SubProcess

# Only the side for which `process.parent` is falsy builds the submission files.
with SubProcess() as process:
    if not process.parent:
        prepare_submission(rank=8)