# Step 1: fine-tune LLM using top result from (fixed) ranker

In [1]:
def step_one():
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    k = 1
    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        target = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                inputs.append(input)
                                target.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        acc = (taskllm.predict(inputs, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()
    
                    loss = taskllm.learn(inputs, target, augment=train.swap_refs) if istrain else None
                    printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                taskllm.save_pretrained('User_keq1_t5xxl_step1')

from Fork import SubProcess
with SubProcess() as process: process.parent or step_one()

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.845       0.845           0           0           0           0        9.54
2                     0           0       0.806       0.767        0.25         0.5           0           0          15
4                     0           0       0.799       0.785       0.167           0           1           1        22.5
8                     0           0       0.741       0.698         0.5        0.75           1           0        44.2
16                    0           0         0.7       0.652       0.692       0.917           1           1        80.9
32                    0           0       0.703       0.706       0.538       0.385        0.75         0.5         158
64                    0           0       0.698       0.692       0.559        0.58       0.577       0.429         310
128                   0           0     

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [2]:
def learn_ranker(rank):
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    import warnings
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    t5.load_adapter('User_keq1_t5xxl_step1', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        warnings.filterwarnings("ignore", message=".*If you want to save 8-bit models.*")
        
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)

                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)

                        if True: # A100 :)
                            loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
                        else: # memory issues
                            inner_batch_size = 2
                            loss = sum(len(x) * rewardpredictor.learn(x, rtensor, augment=reward_augment) 
                                       for x, r in zip(chunked(rhatprompts, inner_batch_size), chunked(rewards.tolist(), inner_batch_size))
                                       for rtensor in (torch.Tensor(r).to(device),)
                                      ) / len(rhatprompts) if istrain else None
                            
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

        rewardpredictor.save_pretrained(f'User_keq1_t5xxl_step2_rankeq{rank}')

from Fork import SubProcess
for rank in range(8, 9):
    with SubProcess() as process: process.parent or learn_ranker(rank)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0       0.709       0.709           0           0           0           0        8.71
2                     0           0       0.702       0.695         0.5           1           0           0          16
4                     0           0         0.7       0.698         0.5         0.5           0           0        31.1
8                     0           0       0.607       0.482       0.714           1           1           1        57.6
16                    0           0       0.677       0.739       0.667       0.625           1           0         121
32                    0           0       0.632       0.583       0.793       0.929           1           1         236
64                    0           0        0.58       0.527       0.772        0.75       0.857        0.75         469
128                   0           0     