# Step 1 Dev Set Labels

In [None]:
def step1_dev_set_labels(*, step1_iter, k):
    """Label the dev set with the step-1 task adapter and dump the predictions.

    For each dev example the two reference titles are embedded together with
    the candidate profile titles (profile entries whose title equals the
    example's own title are excluded).  The ``k`` candidates with the highest
    max-similarity to either reference are appended to the prompt, and the
    task model's argmax is recorded as the prediction.

    Per-batch accuracy is fed to ``ProgressPrinter`` and every prediction is
    written to ``lamp1u_xxl_step1_dev_golds.json``.

    Args:
        step1_iter: Iteration suffix of the step-1 adapter to load
            (``User_keq{k}_t5xxl_step1_iter{step1_iter}``).
        k: Number of profile titles to retrieve per example; also part of the
            adapter name.

    Returns:
        None.  Results are printed and written to disk.
    """
    import json
    from TaskLLM import TaskLLM
    from PersonalizedCitation import dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    dev = dev_loader(batch_size=2)

    # 8-bit base model with the step-1 adapter attached under the name 'taskllm'.
    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    t5.load_adapter(f'User_keq{k}_t5xxl_step1_iter{step1_iter}', 'taskllm')
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    print(f'*** step1_iter: {step1_iter} ***')

    with ProgressPrinter(f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        devgolds = []

        for examples, labels in dev:
            guesses, targets = [], []

            # Each example is scored exactly once, so its id appears exactly
            # once in the output file.
            with torch.no_grad():
                for ex, label in zip(examples, labels):
                    # Keep the filtered list around: the top-k indices below
                    # refer to positions in *this* list, not in ex['profile'].
                    candidates = [v['title']
                                  for v in ex['profile']
                                  if v['title'] != ex['title']]
                    embeddings = dev.embed([ex['ref1'], ex['ref2']] + candidates)
                    # Best similarity of each candidate to either reference.
                    scores = torch.max(embeddings[[0, 1], :] @ embeddings[2:, :].T, dim=0).values
                    # topk raises if asked for more items than are available.
                    nkeep = min(k, len(candidates))
                    index = torch.topk(scores, dim=0, k=nkeep).indices.to('cpu')
                    titles = [f'{candidates[ind]}' for ind in index.tolist()]
                    concat_titles = ' and '.join([f'"{v}"' for v in titles])
                    prompt = dev.append_to_title(ex, concat_titles)
                    guess = taskllm.predict([prompt], augment=dev.swap_refs).argmax(dim=1).item()
                    devgolds.append({'id': ex['id'], 'output': "[2]" if guess else "[1]"})
                    guesses.append(guess)
                    targets.append(int(label == dev.choices[1]))

            if not guesses:
                # Nothing to score in an empty batch.
                continue

            greedyacc = sum(int(g == t) for g, t in zip(guesses, targets)) / len(guesses)

            printer.addobs(greedyacc)

        printer.print()
        printer.autoprint = False

        with open('lamp1u_xxl_step1_dev_golds.json', 'w') as jsonfile:
            json.dump({'task': 'LaMP_1', 'golds': devgolds}, jsonfile)

# Run the step-1 dev labelling for adapter iteration 0 with k=5 retrieved titles.
step1_dev_set_labels(step1_iter=0, k=5)

*** step1_iter: 0 ***
n       5 acc (dev) (since)      dt
1             1.000 (1.000)  5.22 s
2             1.000 (1.000)  8.75 s
4             0.750 (0.500)  15.5 s
8             0.688 (0.625)  28.9 s
16            0.688 (0.688)    56 s
32            0.750 (0.812)  1.83 m
64            0.750 (0.750)  3.57 m
128           0.734 (0.719)  7.04 m
256           0.740 (0.746)    14 m
512           0.745 (0.750)  27.8 m
1024          0.754 (0.763)  54.7 m
1250          0.757 (0.770)   1.1 h


In [None]:
def step2_dev_set_labels(*, step2_iter, step1_iter, k):
    """Label the dev set using reward-predictor-guided retrieval and dump the predictions.

    For each dev example, up to 64 distinct sets of ``k`` profile titles are
    drawn by adding scaled Gumbel noise to the similarity scores and taking the
    top-k.  The reward predictor (adapter ``rhat``) scores every candidate
    set, the ``nvoters`` best-scored prompts are passed to the task model
    (adapter ``taskllm``), and the argmax of the log-sum-exp over those
    predictions is the final answer.

    Per-batch accuracy is fed to ``ProgressPrinter`` and every prediction is
    written to ``lamp1u_xxl_step2_dev_golds.json``.

    Args:
        step2_iter: Iteration suffix of the step-2 (reward predictor) adapter.
        step1_iter: Iteration suffix of the step-1 (task) adapter.
        k: Number of profile titles per candidate set; also part of both
            adapter names.

    Returns:
        None.  Results are printed and written to disk.
    """
    import json
    import re
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedCitation import dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    dev = dev_loader(batch_size=2)

    # One 8-bit base model shared by both adapters.
    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    t5.load_adapter(f'User_keq{k}_t5xxl_step1_iter{step1_iter}', 'taskllm')
    t5.load_adapter(f'User_keq{k}_t5xxl_step2_iter{step2_iter}', 'rhat')
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat", model_id=f'User_keq{k}_t5xxl_step2_iter{step2_iter}')

    # Compiled once instead of on every reward_augment call.
    swap_refs_pattern = re.compile(r'Ref1: (.*)\nRef2: (.*)\nExtra:')

    def reward_augment(inputs):
        """Return copies of the reward prompts with Ref1 and Ref2 exchanged."""
        return [swap_refs_pattern.sub(r'Ref1: \2\nRef2: \1\nExtra:', z)
                for z in inputs]

    gumbel = torch.distributions.gumbel.Gumbel(0, 1)

    def randomized_similarity(ex, nsamples):
        """Sample distinct top-k index sets over the candidate titles of ``ex``.

        Returns ``(candidates, index_sets)``: the filtered title list, and a
        CPU tensor whose rows index into ``candidates``.
        """
        candidates = [v['title']
                      for v in ex['profile']
                      if v['title'] != ex['title']]
        if not candidates:
            # Nothing to retrieve: a single empty selection.
            return candidates, torch.empty((1, 0), dtype=torch.long, device='cpu')

        embeddings = dev.embed([ex['ref1'], ex['ref2']] + candidates)
        # Best similarity of each candidate to either reference.
        scores = torch.max(embeddings[[0, 1], :] @ embeddings[2:, :].T, dim=0).values
        ncand = scores.shape[0]
        # NOTE(review): scores are in profile order, not sorted, so this is the
        # gap between the first and fifth candidate rather than between the
        # best and fifth-best.  Kept as-is (only clamped so short profiles do
        # not raise IndexError) to preserve the existing sampling behaviour;
        # confirm whether a sorted top-1/top-k gap was intended.
        temperature = scores[0].item() - scores[min(4, ncand - 1)].item()
        gumbel_shape = torch.Size([nsamples, ncand])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        # topk raises if asked for more items than are available.
        picks = torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=min(k, ncand)).indices
        return candidates, torch.unique(picks, sorted=False, dim=0).to('cpu')

    print(f'*** step1_iter: {step1_iter} step2_iter: {step2_iter} ***')
    nvoters = 1

    with ProgressPrinter(f'{k} acc (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        devgolds = []

        for examples, labels in dev:
            greedyrewards = []
            for ex, label in zip(examples, labels):
                with torch.no_grad():
                    candidates, randos = randomized_similarity(ex, 64)

                    rhatprompts = []
                    prompts = []
                    for rando in randos:
                        # Indices refer to the filtered candidate list, which
                        # is what was embedded and scored.
                        titles = [f'{candidates[ind]}' for ind in rando.tolist()]
                        concat_titles = ' and '.join([f'"{v}"' for v in titles])
                        prompts.append(dev.append_to_title(ex, concat_titles))
                        rhatprompts.append(
                            f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\n"
                            + '\n'.join([f"Extra: {t}" for t in titles]))

                    rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                    voters = torch.topk(rhats, k=min(nvoters, len(rhats)), dim=0).indices.view(-1).to('cpu').tolist()
                    # Pool the voters' predictions, then take the winning class.
                    guess = (taskllm.predict([prompts[v] for v in voters], augment=dev.swap_refs)
                                    .logsumexp(dim=0, keepdim=True)
                                    .argmax(dim=1)
                                    .item())

                    target = int(label == dev.choices[1])
                    greedyrewards.append(int(guess == target))

                    devgolds.append({'id': ex['id'], 'output': "[2]" if guess else "[1]"})

            if not greedyrewards:
                # An empty batch would otherwise record a NaN accuracy.
                continue

            greedyacc = sum(greedyrewards) / len(greedyrewards)

            printer.addobs(greedyacc)

        printer.print()
        printer.autoprint = False

        with open('lamp1u_xxl_step2_dev_golds.json', 'w') as jsonfile:
            json.dump({'task': 'LaMP_1', 'golds': devgolds}, jsonfile)

# Run the step-2 dev labelling: step-1 adapter iteration 0, step-2 adapter iteration 2, k=5.
step2_dev_set_labels(step1_iter=0, step2_iter=2, k=5)

*** step1_iter: 0 step2_iter: 2 ***
n       5 acc (dev) (since)      dt
1             0.500 (0.500)  13.2 s
2             0.750 (1.000)  24.7 s
4             0.750 (0.750)  43.2 s
8             0.688 (0.625)  1.34 m
16            0.719 (0.750)  2.78 m
32            0.781 (0.844)  5.79 m
64            0.781 (0.781)  11.4 m
128           0.777 (0.773)  22.7 m
256           0.777 (0.777)  45.4 m
512           0.769 (0.760)  1.52 h
1024          0.777 (0.785)  3.04 h
1250          0.775 (0.765)  3.71 h
