# Step 1: fine-tune the task LLM using the top result from the (fixed) ranker

In [2]:
def launch():
    """Run step one (task-LLM fine-tuning) with one worker process per CUDA device.

    Configuration is handed to ``StepOne.step_one`` through environment
    variables (presumably read by the spawned workers -- confirm in StepOne).
    Blocks until every worker has finished.
    """
    import os
    import StepOne
    import torch

    # Fixed settings for this run.
    os.environ.update({'MODEL_TYPE': 'xxl', 'BATCH_SIZE': '1'})
    # Keep a caller-supplied 'r' if present; otherwise default to '1'.
    os.environ.setdefault('r', '1')

    n_gpus = torch.cuda.device_count()
    # spawn calls step_one(process_index, n_gpus) in each child process.
    torch.multiprocessing.spawn(
        StepOne.step_one,
        args=(n_gpus,),
        nprocs=n_gpus,
        join=True,
    )


launch()

******** augment = 1 max_iteration = 5 model_type = xxl *********
1                0.000 (0.000)          2.657 (2.657)          0.250 (0.250)          0.250 (0.250)          0.000 (0.000)    12 s
2                0.000 (0.000)          2.722 (2.787)          0.125 (0.000)          0.125 (0.000)          0.000 (0.000)  27.8 s
4                0.000 (0.000)          2.410 (2.097)          0.156 (0.188)          0.154 (0.183)          0.000 (0.000)  52.4 s
8                0.000 (0.000)          2.187 (1.965)          0.127 (0.097)          0.107 (0.061)          0.000 (0.000)  1.72 m
16               0.000 (0.000)          2.328 (2.468)          0.161 (0.195)          0.195 (0.283)          0.000 (0.000)  3.55 m
32               0.000 (0.000)          2.336 (2.344)          0.181 (0.201)          0.213 (0.232)          0.000 (0.000)   6.9 m
64               0.000 (0.000)          2.352 (2.369)          0.205 (0.228)          0.217 (0.220)          0.000 (0.000)  13.7 m
128              

# Step 2: learn the ranker using the (fixed, pre-fine-tuned) task LLM

In [None]:
def launch():
    """Run step two (ranker learning) with one worker process per CUDA device.

    The task LLM from step one is held fixed; which step-one checkpoint is
    used is selected via the ``STEP1_ITER`` environment variable. All
    configuration is passed through environment variables (presumably read
    by ``StepTwo.step_two`` -- confirm there). Blocks until all workers exit.
    """
    import os
    import StepTwo
    import torch

    fixed_settings = {
        'MODEL_TYPE': 'xxl',
        'BATCH_SIZE': '2',
        'LEARN_BATCH_SIZE': '2',
        'GRAD_FREE_BATCH_SIZE': '16',
    }
    for name, value in fixed_settings.items():
        os.environ[name] = value

    # Overridable settings: an existing environment value wins over the default.
    os.environ.setdefault('r', '1')
    os.environ.setdefault('STEP1_ITER', '0_augment2')

    n_gpus = torch.cuda.device_count()
    # spawn calls step_two(process_index, n_gpus) in each child process.
    torch.multiprocessing.spawn(
        StepTwo.step_two,
        args=(n_gpus,),
        nprocs=n_gpus,
        join=True,
    )


launch()

******** augment = 1 max_iteration = 5 model_type = xxl *********
n                    iter (since)            4 loss (since)           4 rouge (since)       4 ema rouge (since) 4 ema rouge (dev) (since)            nsamps (since)      dt
1                   0.000 (0.000)             0.478 (0.478)             0.229 (0.229)             0.229 (0.229)             0.000 (0.000)           128.000 (128.000)  3.99 m
2                   0.000 (0.000)             0.455 (0.433)             0.186 (0.143)             0.211 (0.194)             0.000 (0.000)           120.750 (113.500)  7.98 m
4                   0.000 (0.000)             0.510 (0.564)             0.297 (0.408)             0.304 (0.397)             0.000 (0.000)           110.250 (99.750)  14.1 m
8                   0.000 (0.000)             0.665 (0.821)             0.316 (0.335)             0.311 (0.318)             0.000 (0.000)            56.625 (3.000)  16.9 m
16                  0.000 (0.000)             0.677 (0.689)          

# Step 3: prepare submission files

In [None]:
def prepare_submission(*, step2_iter, step1_iter, k):
    """Generate LaMP-4 dev/test predictions and write them as submission JSON files.

    For each example:
    1. Sample several candidate profile subsets by perturbing the
       article-to-profile embedding similarities with Gumbel noise.
    2. Score each resulting prompt with the step-2 reward predictor.
    3. Feed the best-scoring prompt to the step-1 task LLM.

    The mean per-example ROUGE-1 of each dev batch is reported through
    ``ProgressPrinter``. Test batches report ``None``.

    Args:
        step2_iter: checkpoint suffix of the reward-predictor adapter, loaded
            from ``User_keq{k}_t5xxl_step2_iter{step2_iter}``.
        step1_iter: checkpoint suffix of the task-LLM adapter, loaded from
            ``User_keq{k}_t5xxl_step1_iter{step1_iter}``.
        k: number of profile items selected into each prompt.

    Side effects:
        Sets torch's global default device to ``'cuda'``, seeds torch's RNG,
        and writes one JSON file per split, named
        ``lamp4u_{dev|test}golds_t5xxl_keq{k}_step1_iter{step1_iter}_step2_iter{step2_iter}.json``.
    """
    import evaluate
    import json
    from itertools import accumulate, chain
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedNews import dev_loader, test_loader
    from ProgressPrinter import ProgressPrinter
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    dev = dev_loader(batch_size=8)
    test = test_loader(batch_size=8)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True)
    # The same checkpoint is loaded under both a raw_ and an ema_ adapter name.
    # Presumably TaskLLM / RewardPredictor look these names up by suffix -- confirm.
    taskllm_model_id = f'User_keq{k}_t5xxl_step1_iter{step1_iter}'
    t5.load_adapter(taskllm_model_id, 'raw_taskllm')
    t5.load_adapter(taskllm_model_id, 'ema_taskllm')
    rhat_model_id = f'User_keq{k}_t5xxl_step2_iter{step2_iter}'
    t5.load_adapter(rhat_model_id, 'raw_rhat')
    t5.load_adapter(rhat_model_id, 'ema_rhat')
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_suffix="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_suffix="rhat", model_id=rhat_model_id)
    rouge_metric = evaluate.load('rouge')
    gradfree_batch_size = 128
    n_randos = 128  # candidate profile subsets sampled per example (before de-duplication)

    gumbel = torch.distributions.gumbel.Gumbel(0, 1)

    def randomized_similarity(embeddings, nsamples):
        """Sample up to ``nsamples`` distinct index sets of ``min(k, n_profile)`` profile items.

        Row 0 of ``embeddings`` is the article and rows 1: are the profile
        items. Each sample takes the top-k of the similarity scores plus
        temperature-scaled Gumbel noise. Duplicate index rows are removed.
        """
        scores = embeddings[0, :] @ embeddings[1:, :].T
        # NOTE(review): `scores` is in profile order, not sorted. This is therefore
        # the gap between the 1st and 5th profile items' scores, not the gap between
        # the top-1 and top-5 similarities, and it can be zero or negative.
        # Kept unchanged, presumably to match the sampler used during training -- confirm.
        temperature = scores[0].item() - scores[min(scores.shape[0] - 1, 4)].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape).to(scores.device)
        safek = min(k, scores.shape[0])
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=safek).indices, sorted=False, dim=0)

    def inner_batch(func, inner_batch_size, inputs):
        """Call ``func`` on aligned chunks of each sequence in ``inputs``; return the list of results."""
        from more_itertools import chunked
        return [func(*ib) for ib in zip(*[chunked(g, inner_batch_size) for g in inputs])]

    def flatten(groups):
        """Concatenate an iterable of sequences into one list (linear time, unlike ``sum(groups, [])``)."""
        return list(chain.from_iterable(groups))

    def boundaries(groups):
        """Return ``[0, len(g0), len(g0)+len(g1), ...]``, the slice offsets of each group once flattened."""
        return list(accumulate(map(len, groups), initial=0))

    def clean(text):
        """Collapse whitespace runs to single spaces, then truncate to 256 characters."""
        return ' '.join(text.split())[:256]

    print(f'*** step1_iter: {step1_iter} step2_iter: {step2_iter} ***')

    devgolds, testgolds = [], []
    with ProgressPrinter(f'{k} rouge (dev)') as printer:
        for isdev, (examples, labels) in interleave(dev, test, sequential=True):
            with torch.no_grad():
                # Per example: the article text followed by each profile item's text.
                texts_to_embed = [[clean(ex['article'])] + [clean(v['text']) for v in ex['profile']]
                                  for ex in examples]
                embeddings = torch.cat(inner_batch(func=dev.embed,
                                                   inner_batch_size=gradfree_batch_size,
                                                   inputs=(flatten(texts_to_embed),)),
                                       dim=0)
                splits = boundaries(texts_to_embed)
                randos = [randomized_similarity(embeddings[a:b, :], n_randos)
                          for a, b in zip(splits, splits[1:])]
                # One candidate prompt per sampled profile subset.
                prompts = [[dev.prepend_to_prompt(ex, [ex['profile'][ind] for ind in indices])
                            for indices in rando.to('cpu').tolist()]
                           for ex, rando in zip(examples, randos)]
                rhats = torch.cat(inner_batch(func=rewardpredictor.predict,
                                              inner_batch_size=gradfree_batch_size,
                                              inputs=(flatten(prompts),)),
                                  dim=0)
                splits = boundaries(prompts)
                # Greedy choice: each example's candidate with the highest predicted reward.
                # NOTE(review): argmax() flattens rhats[a:b]. This assumes predict() returns
                # one score per prompt -- confirm.
                greedyaction = [rhats[a:b].argmax().item() for a, b in zip(splits, splits[1:])]
                greedyprompts = [prompt[a] for prompt, a in zip(prompts, greedyaction)]
                guesses = flatten(inner_batch(func=taskllm.generate,
                                              inner_batch_size=gradfree_batch_size,
                                              inputs=(greedyprompts,)))
                if isdev:
                    # With use_aggregator=False, the metric returns one ROUGE-1 score per
                    # (prediction, reference) pair, so a single batched call suffices.
                    per_example = rouge_metric.compute(predictions=guesses,
                                                       references=list(labels),
                                                       use_aggregator=False)['rouge1']
                    rewards = sum(per_example) / len(per_example) if per_example else None
                else:
                    rewards = None

                for ex, guess in zip(examples, guesses):
                    (devgolds if isdev else testgolds).append({'id': ex['id'], 'output': guess})

            printer.addobs(rewards)

    for wut, golds in (('dev', devgolds), ('test', testgolds)):
        with open(f'lamp4u_{wut}golds_t5xxl_keq{k}_step1_iter{step1_iter}_step2_iter{step2_iter}.json',
                  'w', encoding='utf-8') as jsonfile:
            json.dump({'task': 'LaMP_4', 'golds': golds}, jsonfile)


prepare_submission(k=4, step1_iter='0_augment2', step2_iter='0_augment1')

*** step1_iter: 0_augment2 step2_iter: 0_augment1 ***
n       4 rouge (dev) (since)      dt
1               0.114 (0.114)  1.02 m
2               0.181 (0.249)  1.79 m
4               0.217 (0.254)  3.68 m
8               0.219 (0.220)  7.55 m
16              0.232 (0.245)  15.7 m
32              0.222 (0.212)  32.9 m
64              0.217 (0.212)   1.1 h
128             0.224 (0.230)  2.23 h
256             0.228 (0.232)  4.47 h


> Hi,
> 
> This is the result of your latest submission to LaMP benchmark
> 
> {"rouge-1": 0.22116036597254493, "rouge-L": 0.20283352960160614}