# Step 1: fine-tune LLM using top result from (fixed) ranker

In [None]:
def launch():
    """Start step one: spawn ``StepOne.step_one`` once per visible CUDA device.

    The number of processes equals ``torch.cuda.device_count()`` and that
    same count is forwarded to every worker as its extra argument. The call
    blocks until the spawned processes have been joined.
    """
    import os
    import StepOne
    import torch

    # Configuration is placed in the environment before the workers start
    # (presumably read by StepOne -- confirm there).
    os.environ.update({'MODEL_TYPE': 'base', 'BATCH_SIZE': '32'})

    gpu_count = torch.cuda.device_count()
    torch.multiprocessing.spawn(
        StepOne.step_one,
        args=(gpu_count,),
        nprocs=gpu_count,
        join=True,
    )


launch()

******** augment = 8 max_iteration = 5 model_type = base *********
n              iter (since)      4 loss (since)       4 acc (since)   4 ema acc (since) 4 acc (dev) (since)      dt
1             0.000 (0.000)       1.776 (1.776)       0.594 (0.594)       0.594 (0.594)       0.000 (0.000)  16.1 s
2             0.000 (0.000)       1.814 (1.852)       0.578 (0.562)       0.562 (0.531)       0.000 (0.000)  27.3 s
4             0.000 (0.000)       1.856 (1.898)       0.539 (0.500)       0.648 (0.734)       0.000 (0.000)    50 s
8             0.000 (0.000)       1.335 (0.814)       0.629 (0.719)       0.691 (0.734)       0.000 (0.000)  1.63 m
16            0.000 (0.000)       1.112 (0.889)       0.684 (0.738)       0.707 (0.723)       0.000 (0.000)  3.23 m
32            0.000 (0.000)       0.960 (0.809)       0.715 (0.746)       0.718 (0.729)       0.000 (0.000)  6.56 m
64            0.000 (0.000)       0.832 (0.703)       0.751 (0.787)       0.749 (0.780)       0.000 (0.000)  13.1 m
128  

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [None]:
def launch():
    """Start step two: spawn ``StepTwo.step_two`` once per visible CUDA device.

    The number of processes equals ``torch.cuda.device_count()`` and that
    same count is forwarded to every worker as its extra argument. The call
    blocks until the spawned processes have been joined.
    """
    import os
    import StepTwo
    import torch

    # Configuration is placed in the environment before the workers start
    # (presumably read by StepTwo -- confirm there). STEP1_ITER names the
    # step-one run whose output this step builds on.
    os.environ.update({
        'MODEL_TYPE': 'base',
        'BATCH_SIZE': '32',
        'LEARN_BATCH_SIZE': '16',
        'STEP1_ITER': '1_augment8',
    })

    gpu_count = torch.cuda.device_count()
    torch.multiprocessing.spawn(
        StepTwo.step_two,
        args=(gpu_count,),
        nprocs=gpu_count,
        join=True,
    )


launch()

******** augment = 8 max_iteration = 5 model_type = base *********
n              iter (since)      4 loss (since)       4 acc (since)   4 acc ema (since) 4 acc (dev) (since)      nsamps (since)      dt
1             0.000 (0.000)       0.250 (0.250)       0.875 (0.875)       0.875 (0.875)       0.000 (0.000)     203.281 (203.281)  4.81 m
2             0.000 (0.000)       0.362 (0.473)       0.828 (0.781)       0.828 (0.781)       0.000 (0.000)     103.984 (4.688)  6.79 m
4             0.000 (0.000)       0.388 (0.415)       0.836 (0.844)       0.852 (0.875)       0.000 (0.000)      55.250 (6.516)  10.7 m
8             0.000 (0.000)       0.413 (0.438)       0.844 (0.852)       0.852 (0.852)       0.000 (0.000)      30.805 (6.359)  18.4 m
16            0.000 (0.000)       0.437 (0.461)       0.826 (0.809)       0.840 (0.828)       0.000 (0.000)      18.604 (6.402)  33.7 m
32            0.000 (0.000)       0.460 (0.483)       0.802 (0.777)       0.804 (0.768)       0.000 (0.000)      12

# Step 3: Prepare Submission Files

In [None]:
def prepare_submission_probensemble(*, nvoters, step2_iter, step1_iter, k):
    """Predict labels for the dev and test sets and write them to JSON files.

    For every example, random subsets of profile entries are sampled, one
    prompt is built per subset, the prompts are scored by the reward
    predictor, and the task LLM outputs for the best-scoring prompts are
    combined into a single guess.

    Args:
        nvoters: Upper bound on how many top-scoring prompts per example
            contribute to the final guess.
        step2_iter: Suffix used to build the step-two adapter name (loaded
            as 'raw_rhat' / 'ema_rhat') and the output file name.
        step1_iter: Suffix used to build the step-one adapter name (loaded
            as 'raw_taskllm' / 'ema_taskllm') and the output file name.
        k: Upper bound on the number of profile entries placed in each
            prompt; also part of the adapter and output file names.

    Side effects:
        Sets the torch default device to 'cuda', seeds the torch RNG, prints
        progress, and writes
        ``lamp2u_{dev,test}golds_t5base_keq{k}_step1_iter{step1_iter}_step2_iter{step2_iter}_nvoters{nvoters}.json``
        in the current working directory.
    """
    import json
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedNewsCat import dev_loader, test_loader
    from ProgressPrinter import ProgressPrinter
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave
    
    device = 'cuda'
    torch.set_default_device(device)
    # Fixed seed so the Gumbel sampling below is repeatable from run to run.
    torch.manual_seed(8675309)

    dev = dev_loader(batch_size=8)
    test = test_loader(batch_size=8)

    # One base model carries four adapters; the same saved adapter is loaded
    # under both the 'raw_' and 'ema_' names for each step.
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'raw_taskllm')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'ema_taskllm')
    t5.load_adapter(f'User_keq{k}_t5base_step2_iter{step2_iter}', 'raw_rhat')
    t5.load_adapter(f'User_keq{k}_t5base_step2_iter{step2_iter}', 'ema_rhat')
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_suffix="taskllm", choices=dev.choices)
    rewardpredictor = RewardPredictor(t5=t5, adapter_suffix="rhat", model_id=f'User_keq{k}_t5base_step2_iter{step2_iter}')

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(embeddings, nsamples):
        """Sample index sets of high-scoring rows using Gumbel-perturbed scores.

        Row 0 of ``embeddings`` is dotted against every remaining row to get
        one score per remaining row. ``nsamples`` noisy copies of the scores
        are drawn, the top ``min(k, number of scores)`` indices of each copy
        are taken, and duplicate index rows are removed.
        """
        scores = embeddings[0,:] @ embeddings[1:,:].T
        # Noise scale: gap between the first score and the fifth (or the last
        # one when there are fewer than five).
        # NOTE(review): scores are in input order here, not sorted, so this
        # gap can be zero or negative -- confirm this is intended and matches
        # whatever sampling the training steps use.
        # NOTE(review): scores[0] raises if there are no rows besides row 0
        # (e.g. an empty profile) -- confirm that cannot happen.
        temperature = scores[0].item() - scores[min(scores.shape[0]-1, 4)].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape).to(scores.device)
        safek = min(k, scores.shape[0])
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=safek).indices, sorted=False, dim=0)

    def inner_batch(func, inner_batch_size, inputs):
        """Call ``func`` on aligned chunks of each sequence in ``inputs``.

        Every sequence in ``inputs`` is cut into chunks of at most
        ``inner_batch_size`` items; ``func`` receives one chunk per sequence
        as positional arguments. Returns the list of ``func`` results.
        """
        from more_itertools import chunked
        return [ func(*ib) for ib in zip(*[ chunked(g, inner_batch_size) for g in inputs ]) ]

    print(f'*** step1_iter: {step1_iter} step2_iter: {step2_iter} nvoters {nvoters} ***')
    
    with ProgressPrinter(f'{k} acc (dev)') as printer:
        # Accumulated {'id', 'output'} records, one list per data split.
        devgolds, testgolds = [], []
        # Running totals with a leading 0, used below as slice boundaries
        # into the flattened per-batch tensors.
        cumsum = lambda z, acc=0: [0] + [ acc := acc + v for v in z ]

        # presumably isdev is truthy for batches drawn from `dev` and falsy
        # for `test` -- confirm against Util.interleave
        for isdev, (examples, labels) in interleave(dev, test, sequential=True):
            with torch.no_grad():
                # Per example: the article text first, then every profile
                # text; each is whitespace-normalized and cut to 256 chars.
                texts_to_embed = [ [ text[:256]
                                     for text in (' '.join(ex['article'].split()), )
                                   ] +
                                   [ text[:256]
                                     for v in ex['profile']
                                     for text in (' '.join(v['text'].split()), )
                                   ]
                                   for ex in examples
                                 ]
                # Embed all texts of the batch in one flat list, at most 128
                # texts per dev.embed call.
                embeddings = torch.cat(inner_batch(func = dev.embed,
                                                   inner_batch_size = 128,
                                                   inputs = (sum(texts_to_embed, []),)
                                                  ),
                                       dim=0)
                splits = cumsum(map(len, texts_to_embed))
                # One tensor of sampled profile-index rows per example
                # (64 noisy draws each, duplicates removed).
                randos = [ randomized_similarity(embeddings[a:b,:], 64) for a, b in zip(splits, splits[1:]) ]
                # One candidate prompt per sampled index row.
                prompts = [ [ dev.prepend_to_prompt(ex, [ ex['profile'][ind] for ind in indices ])
                              for indices in rando.to('cpu').tolist()
                            ]
                            for ex, rando in zip(examples, randos)
                          ]
                # Score every candidate prompt with the reward predictor.
                rhats = torch.cat(inner_batch(func = rewardpredictor.predict,
                                              inner_batch_size = 128,
                                              inputs = (sum(prompts, []),)
                                             ),
                                  dim=0)
                splits = cumsum(map(len, prompts))
                # Keep, per example, the prompts with the highest predicted
                # reward; at most nvoters of them.
                votingprompts = [ [ prompt[v] for v in torch.topk(rhats[a:b].view(-1), k=min(nvoters, b-a)).indices.to('cpu').tolist() ]
                                    for a, b, prompt in zip(splits, splits[1:], prompts)
                                ]
                predicts = torch.cat(inner_batch(func = taskllm.predict,
                                                 inner_batch_size = 128,
                                                 inputs = (sum(votingprompts, []),)
                                                ),
                                     dim=0)
                splits = cumsum(map(len, votingprompts))
                # Combine each example's voter rows with logsumexp over the
                # voter axis and pick the column with the largest result.
                # presumably columns are log-probabilities aligned with
                # dev.choices -- confirm against TaskLLM.predict
                guesses = torch.cat([ predicts[a:b,:].logsumexp(dim=0, keepdim=True).argmax(dim=1)
                                      for a, b in zip(splits, splits[1:])
                                    ],
                                    dim=0)

                # Accuracy is only computed for dev batches; for test batches
                # the labels are not used and None is reported instead.
                if isdev:
                    targets = [ dev.choices.index(label) for label in labels ]
                    targets = torch.Tensor(targets).long().to(guesses.device)
                    acc = (guesses == targets).float().mean().item()
                else:
                    acc = None

                for ex, guess in zip(examples, guesses):
                    (devgolds if isdev else testgolds).append({ 'id': ex['id'], 'output': dev.choices[guess] })

            printer.addobs(acc)

        # Emit a final progress line, then turn automatic printing off.
        printer.print()
        printer.autoprint = False

        # One JSON file per split, each holding the task name and the
        # collected predictions.
        for wut, golds in ( ('dev', devgolds), ('test', testgolds) ):
            with open(f'lamp2u_{wut}golds_t5base_keq{k}_step1_iter{step1_iter}_step2_iter{step2_iter}_nvoters{nvoters}.json', 'w') as jsonfile:
                json.dump({ 'task': 'LaMP_2', 'golds': golds }, jsonfile)
            
# Sweep the ensemble size over the odd values 1..11, writing one pair of
# dev/test prediction files per setting.
for nvoters in range(1, 12, 2):
    prepare_submission_probensemble(
        nvoters=nvoters,
        k=4,
        step1_iter='1_augment8',
        step2_iter='1_augment8',
    )

*** step1_iter: 1_augment8 step2_iter: 1_augment8 nvoters 1 ***
n       4 acc (dev) (since)      dt
1             0.625 (0.625)  3.79 s
2             0.688 (0.750)  7.05 s
4             0.781 (0.875)  12.7 s
8             0.828 (0.875)  24.1 s
16            0.859 (0.891)  45.4 s
32            0.852 (0.844)  1.48 m
64            0.838 (0.824)  2.96 m
128           0.834 (0.830)  5.99 m
256           0.836 (0.906)    13 m
*** step1_iter: 1_augment8 step2_iter: 1_augment8 nvoters 3 ***
n       4 acc (dev) (since)      dt
1             0.625 (0.625)  3.24 s
2             0.688 (0.750)  6.58 s
4             0.781 (0.875)  12.4 s
8             0.828 (0.875)    24 s
16            0.859 (0.891)  45.8 s
32            0.859 (0.859)  1.51 m
64            0.830 (0.801)  3.03 m
128           0.834 (0.838)  6.14 m
256           0.835 (0.875)  13.3 m
*** step1_iter: 1_augment8 step2_iter: 1_augment8 nvoters 5 ***
n       4 acc (dev) (since)      dt
1             0.750 (0.750)  3.32 s
2             0.

*** step1_iter: 1_augment8 step2_iter: 1_augment8 nvoters 9 ***
n       4 acc (dev) (since)      dt
1             0.750 (0.750)  4.06 s
2             0.750 (0.750)  7.63 s
4             0.812 (0.875)  13.8 s
8             0.844 (0.875)  26.2 s
16            0.867 (0.891)  49.5 s
32            0.859 (0.852)  1.63 m
64            0.830 (0.801)  3.27 m
128           0.836 (0.842)  6.61 m
256           0.837 (0.875)  14.2 m
*** step1_iter: 1_augment8 step2_iter: 1_augment8 nvoters 11 ***
n       4 acc (dev) (since)      dt
1             0.750 (0.750)  3.54 s
2             0.750 (0.750)  7.17 s
4             0.812 (0.875)  13.5 s
8             0.844 (0.875)  26.2 s
16            0.867 (0.891)  50.3 s
32            0.859 (0.852)  1.67 m
64            0.832 (0.805)  3.35 m
128           0.836 (0.840)  6.76 m
256           0.837 (0.875)  14.5 m
