# Step 1 Dev Set Labels

In [None]:
def step1_dev_set_labels(*, step1_iter, k):
    """Generate step-1 task-LLM outputs for the dev set and save them as JSON.

    For every dev example, the article and each profile entry are
    whitespace-normalised, truncated to 256 characters and embedded with
    ``dev.embed``.  The (up to) ``k`` profile entries with the highest
    dot-product score against the article embedding are passed to
    ``dev.prepend_to_prompt``, and the task LLM generates one output per
    prompt.  ROUGE-1 of each batch is reported through ``ProgressPrinter``
    and all generations are written to ``lamp4u_step1_dev_golds.json`` in
    the current working directory.

    Args:
        step1_iter: Checkpoint tag; interpolated into the adapter id
            ``User_keq{k}_t5base_step1_iter{step1_iter}``.
        k: Number of profile entries to retrieve per example; also part of
            the adapter id.  Examples with fewer than ``k`` profile entries
            use all of the entries they have.
    """
    import json
    from itertools import accumulate, chain

    import evaluate
    import torch
    from more_itertools import chunked
    from transformers import T5ForConditionalGeneration

    from PersonalizedNews import dev_loader
    from ProgressPrinter import ProgressPrinter
    from TaskLLM import TaskLLM

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    dev = dev_loader(batch_size=8)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_model_id = f'User_keq{k}_t5base_step1_iter{step1_iter}'
    # Both adapter slots are populated from the same saved checkpoint.
    t5.load_adapter(taskllm_model_id, 'raw_taskllm')
    t5.load_adapter(taskllm_model_id, 'ema_taskllm')

    taskllm = TaskLLM(t5=t5, adapter_suffix="taskllm")
    rouge_metric = evaluate.load('rouge')
    gradfree_batch_size = 128

    def inner_batch(func, inner_batch_size, inputs):
        """Apply ``func`` to aligned chunks of each sequence in ``inputs``."""
        return [func(*ib)
                for ib in zip(*[chunked(g, inner_batch_size) for g in inputs])]

    def clean(text):
        """Collapse runs of whitespace and keep the first 256 characters."""
        return ' '.join(text.split())[:256]

    print(f'*** step1_iter: {step1_iter} ***')

    devgolds = []
    with ProgressPrinter(f'{k} rouge (dev)') as printer:
        for examples, labels in dev:
            with torch.no_grad():
                # Per example: [article, profile_0, profile_1, ...].
                texts_to_embed = [
                    [clean(ex['article'])]
                    + [clean(v['text']) for v in ex['profile']]
                    for ex in examples
                ]
                # Flatten in linear time (sum(list_of_lists, []) is quadratic).
                flat_texts = list(chain.from_iterable(texts_to_embed))
                embeddings = torch.cat(
                    inner_batch(func=dev.embed,
                                inner_batch_size=gradfree_batch_size,
                                inputs=(flat_texts,)),
                    dim=0)

                # splits[i]:splits[i+1] is the row range of example i in
                # `embeddings`; the first row of each range is the article.
                splits = list(accumulate(map(len, texts_to_embed), initial=0))

                indices = []
                for a, b in zip(splits, splits[1:]):
                    similarity = embeddings[a, :] @ embeddings[a + 1:b, :].T
                    # Clamp k so short profiles do not make topk raise.
                    top = torch.topk(similarity, dim=0, k=min(k, b - a - 1))
                    indices.append(top.indices)

                prompts = [
                    dev.prepend_to_prompt(
                        ex,
                        [ex['profile'][ind] for ind in index.to('cpu').tolist()])
                    for ex, index in zip(examples, indices)
                ]
                guesses = taskllm.generate(prompts)
                rouge1 = rouge_metric.compute(predictions=guesses,
                                              references=labels)['rouge1']

                for ex, guess in zip(examples, guesses):
                    devgolds.append({'id': ex['id'], 'output': guess})

            printer.addobs(rouge1)

    with open('lamp4u_step1_dev_golds.json', 'w', encoding='utf-8') as jsonfile:
        json.dump({'task': 'LaMP_4', 'golds': devgolds}, jsonfile)
            
# Run the dev-set labelling with k=4 retrieved profile entries, loading the
# step-1 adapter whose id carries the iteration tag '17_augment4'.
step1_dev_set_labels(k=4, step1_iter='17_augment4')

*** step1_iter: 17_augment4 ***
n       4 rouge (dev) (since)      dt
1               0.164 (0.164)  3.28 s
2               0.201 (0.239)   5.4 s
4               0.208 (0.214)  10.1 s
8               0.208 (0.208)  18.8 s
16              0.209 (0.210)  34.6 s
32              0.192 (0.176)  1.13 m
64              0.184 (0.176)  2.25 m
128             0.190 (0.196)  4.48 m
241             0.188 (0.186)  8.35 m


Note: This cell only writes the step 1 dev golds (`lamp4u_step1_dev_golds.json`). The step 2 dev golds are generated later, when preparing the submission files.