# Step 1: fine-tune the task LLM using the top-k results from a (fixed) ranker

In [None]:
def step_one(*, k, max_iteration):
    import evaluate
    from TaskLLM import TaskLLM
    from PersonalizedNews import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=8, augment=True)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    rouge_metric = evaluate.load('rouge')
    
    with ProgressPrinter('iter', f'{k} loss', f'{k} rouge1', f'{k} rouge1 (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                with torch.no_grad():
                    prompts = []
    
                    for ex in examples:
                        embeddings = dev.embed( [ ex['article'] ] + 
                                                [ v['text']
                                                  for v in ex['profile']
                                                ])
                        index = torch.topk(embeddings[0,:] @ embeddings[1:,:].T, dim=0, k=k).indices.to('cpu').tolist()
                        profile_examples = [ ex['profile'][ind] for ind in index ]
                        prompt = dev.prepend_to_prompt(ex, profile_examples)
                        prompts.append(prompt)

                    guesses = taskllm.generate(prompts)
                    scores = rouge_metric.compute(predictions=guesses, references=labels)['rouge1']
    
                loss = taskllm.learn(prompts, labels) if istrain else None
                printer.addobs(iteration, loss, scores if istrain else None, scores if not istrain else None)

            printer.print()
            printer.autoprint = False
            taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}')

# Fine-tune with the top-4 profile items; saves adapters User_keq4_t5base_step1_iter0 .. iter4.
step_one(max_iteration=5, k=4)

n                 iter (since)         4 loss (since)       4 rouge1 (since) 4 rouge1 (dev) (since)      dt
1                0.000 (0.000)          2.815 (2.815)          0.172 (0.172)          0.000 (0.000)  5.29 s
2                0.000 (0.000)          3.177 (3.540)          0.126 (0.081)          0.000 (0.000)  9.11 s
4                0.000 (0.000)          3.121 (3.064)          0.141 (0.156)          0.000 (0.000)  19.3 s
8                0.000 (0.000)          3.138 (3.155)          0.148 (0.155)          0.000 (0.000)  32.1 s
16               0.000 (0.000)          3.148 (3.158)          0.159 (0.170)          0.000 (0.000)   1.1 m
32               0.000 (0.000)          3.141 (3.134)          0.155 (0.150)          0.000 (0.000)  2.23 m
64               0.000 (0.000)          3.065 (2.990)          0.168 (0.182)          0.000 (0.000)  4.73 m
128              0.000 (0.000)          3.041 (3.017)          0.170 (0.171)          0.000 (0.000)  9.46 m
256              0.000 (0.00

# Step 2: learn the ranker using the (fixed) task LLM fine-tuned in Step 1

In [None]:
def learn_ranker(*, step1_iter, max_iteration, k):
    import evaluate
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedNews import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=8, augment=True)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['article'] ] + 
                                [ v['text']
                                  for v in ex['profile']
                                ])
        scores = embeddings[0,:] @ embeddings[1:,:].T
        temperature = scores[0].item() - scores[min(scores.shape[0], 4)].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    rouge_metric = evaluate.load('rouge')

    with ProgressPrinter('iter', f'{k} loss', f'{k} rouge1', f'{k} rouge1 (dev)', 'samps') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedyrewards, allloss, nsamps = [], [], []

                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        
                        prompts, rhatprompts = [], []
                        for rando in randos:
                            profile_examples = [ ex['profile'][ind] for ind in rando ]
                            prompt = dev.prepend_to_prompt(ex, profile_examples)
                            prompts.append(prompt)
                            rhatprompts.append(prompt)

                        rhats = rewardpredictor.predict(rhatprompts)
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=10)
                        exploit = [ exploit.item() ]
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps.append(len(actionind))
                        guesses = taskllm.generate([ prompts[a] for a in actionind ])
                        rewards = rouge_metric.compute(predictions=guesses,
                                                       references=[label]*len(guesses),
                                                       use_aggregator=False)['rouge1']
                        greedyrewards.append(rewards[0])

                    if istrain:
                        inner_batch_size = 4
                        loss = sum(
                            len(inner_batch[0]) * 
                            rewardpredictor.learn([ rhatprompts[a] for a in inner_batch[0] ], 
                                                    torch.Tensor([ [ r ] for r in inner_batch[1] ]).to(device)
                                                 )
                            for inner_batch in zip(chunked(actionind, inner_batch_size), chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedyreward = torch.Tensor(greedyrewards).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None
                nsamps = torch.Tensor(nsamps).float().mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedyreward if istrain else None, greedyreward if not istrain else None, nsamps)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}')

# Train the ranker for 12 iterations on top of adapter 'User_keq4_t5base_step1_iter1_1024'.
# NOTE(review): step_one above only saves iter0..iter4; confirm where the '1_1024' checkpoint comes from.
learn_ranker(step1_iter='1_1024', max_iteration=12, k=4)

In [1]:
def learn_ranker(*, step1_iter, max_iteration, k):
    import evaluate
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedNews import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=8, augment=True)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        embeddings = dev.embed( [ ex['article'] ] + 
                                [ v['text']
                                  for v in ex['profile']
                                ])
        scores = embeddings[0,:] @ embeddings[1:,:].T
        temperature = scores[0].item() - scores[min(scores.shape[0], 4)].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0).to('cpu')

    rouge_metric = evaluate.load('rouge')

    with ProgressPrinter('iter', f'{k} loss', f'{k} rouge1', f'{k} rouge1 (dev)', 'samps') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedyrewards, allloss, nsamps = [], [], []

                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        
                        prompts, rhatprompts = [], []
                        for rando in randos:
                            profile_examples = [ ex['profile'][ind] for ind in rando ]
                            prompt = dev.prepend_to_prompt(ex, profile_examples)
                            prompts.append(prompt)
                            rhatprompts.append(prompt)

                        rhats = rewardpredictor.predict(rhatprompts)
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=10)
                        exploit = [ exploit.item() ]
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps.append(len(actionind))
                        guesses = taskllm.generate([ prompts[a] for a in actionind ])
                        rewards = rouge_metric.compute(predictions=guesses,
                                                       references=[label]*len(guesses),
                                                       use_aggregator=False)['rouge1']
                        greedyrewards.append(rewards[0])

                    if istrain:
                        inner_batch_size = 4
                        loss = sum(
                            len(inner_batch[0]) * 
                            rewardpredictor.learn([ rhatprompts[a] for a in inner_batch[0] ], 
                                                    torch.Tensor([ [ r ] for r in inner_batch[1] ]).to(device)
                                                 )
                            for inner_batch in zip(chunked(actionind, inner_batch_size), chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedyreward = torch.Tensor(greedyrewards).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None
                nsamps = torch.Tensor(nsamps).float().mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedyreward if istrain else None, greedyreward if not istrain else None, nsamps)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}')

# Train the ranker for 12 iterations on top of the adapter saved after step-1 iteration 1
# ('User_keq4_t5base_step1_iter1').
learn_ranker(step1_iter=1, max_iteration=12, k=4)

n                 iter (since)         4 loss (since)       4 rouge1 (since) 4 rouge1 (dev) (since)          samps (since)      dt
1                0.000 (0.000)          0.421 (0.421)          0.170 (0.170)          0.000 (0.000)         46.188 (46.188)  1.11 m
2                0.000 (0.000)          0.378 (0.335)          0.147 (0.124)          0.000 (0.000)         33.719 (21.250)  1.76 m
4                0.000 (0.000)          0.426 (0.474)          0.187 (0.227)          0.000 (0.000)         25.375 (17.031)  2.95 m
8                0.000 (0.000)          0.452 (0.479)          0.187 (0.188)          0.000 (0.000)         19.922 (14.469)  4.96 m
16               0.000 (0.000)          0.473 (0.494)          0.194 (0.201)          0.000 (0.000)         15.938 (11.953)  8.88 m
32               0.000 (0.000)          0.471 (0.468)          0.189 (0.184)          0.000 (0.000)         12.684 (9.430)    16 m
64               0.000 (0.000)          0.476 (0.481)          0.193 (0.197)  