# Step 1: fine-tune LLM using top result from (fixed) ranker

In [7]:
def step_one():
    from TaskLLM import PeftTaskLLM
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm = PeftTaskLLM(peft_config, t5=t5)

    k = 1
    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        target = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                inputs.append(input)
                                target.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        acc = (taskllm.predict(inputs, augment=train.swap_refs).argmax(dim=1) == target).float().mean().item()
    
                    loss = taskllm.learn(inputs, target, augment=train.swap_refs) if istrain else None
                    printer.addobs(iteration, loss, acc if istrain else None, acc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                taskllm.save_pretrained('User_keq1_t5base_step1')

from Fork import SubProcess
with SubProcess() as process: process.parent or step_one()

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.708       0.708           0           0           0           0       0.946
2                     0           0        0.69       0.673         0.5           1           0           0        1.42
4                     0           0        0.68       0.661       0.667           1           1           1        1.95
8                     0           0       0.693       0.703       0.571         0.5           1           0        3.29
16                    0           0       0.696         0.7       0.577       0.583         0.5        0.25         5.6
32                    0           0       0.694       0.693       0.538         0.5        0.75           1        10.7
64                    0           0        0.69       0.686       0.588        0.64       0.654       0.571        20.4
128                   0           0     

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [13]:
def learn_ranker(rank):
    from RewardPredictor import PeftRewardPredictor
    from TaskLLM import PeftTaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = PeftConfig.from_pretrained('User_keq1_t5base_step1')
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm = PeftTaskLLM(peft_config, t5=t5, model_id='User_keq1_t5base_step1')

    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftRewardPredictor(rhat_peft_config, t5=rhat_t5)
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                            
                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess
for rank in range(8, 9):
    with SubProcess() as process: process.parent or learn_ranker(rank)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0        0.71        0.71           1           1           0           0        1.23
2                     0           0       0.694       0.678         0.5           0           0           0        1.91
4                     0           0       0.683        0.66       0.667           1           1           1        2.98
8                     0           0       0.677       0.673       0.857           1           1           0        5.56
16                    0           0       0.671       0.664       0.769       0.667       0.667         0.5        10.2
32                    0           0       0.684       0.697       0.769       0.769       0.667       0.667        20.1
64                    0           0       0.665       0.645       0.706        0.64       0.692       0.714        39.6
128                   0           0     

In [14]:
# logit sum expit ... almost identical
def learn_ranker(rank):
    from RewardPredictor import PeftRewardPredictor
    from TaskLLM import PeftTaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = PeftConfig.from_pretrained('User_keq1_t5base_step1')
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm = PeftTaskLLM(peft_config, t5=t5, model_id='User_keq1_t5base_step1')

    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftRewardPredictor(rhat_peft_config, t5=rhat_t5)
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                            
                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess
for rank in range(8, 9):
    with SubProcess() as process: process.parent or learn_ranker(rank)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0        0.71        0.71           1           1           0           0        1.21
2                     0           0       0.694       0.678         0.5           0           0           0         1.9
4                     0           0       0.683        0.66       0.667           1           1           1        3.01
8                     0           0       0.677       0.673       0.857           1           1           0         5.7
16                    0           0       0.671       0.664       0.769       0.667       0.667         0.5        10.4
32                    0           0       0.684       0.697       0.769       0.769       0.667       0.667        20.5
64                    0           0       0.665       0.645       0.706        0.64       0.692       0.714          40
128                   0           0     

<class 'KeyboardInterrupt'>

KeyboardInterrupt: 



  File "/tmp/ipykernel_783552/4133691015.py", line 105, in <module>
    with SubProcess() as process: process.parent or learn_ranker(rank)
  File "/tmp/ipykernel_783552/4133691015.py", line 90, in learn_ranker
    loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
  File "/home/pmineiro/lampstuff/personalized_citation/RewardPredictor.py", line 62, in learn
    self._optim.zero_grad()
  File "/home/pmineiro/miniconda3/envs/lampstuff/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pmineiro/lampstuff/personalized_citation/RewardPredictor.py", line 45, in forward
    augment_outputs = self(augment(data))
  File "/home/pmineiro/miniconda3/envs/lampstuff/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pmineiro/lampstuff/personalized_citation/RewardPredictor.py", line 4

In [16]:
# with data doubling ...
def learn_ranker(rank):
    from RewardPredictor import PeftRewardPredictor
    from TaskLLM import PeftTaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = PeftConfig.from_pretrained('User_keq1_t5base_step1')
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm = PeftTaskLLM(peft_config, t5=t5, model_id='User_keq1_t5base_step1')

    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftRewardPredictor(rhat_peft_config, t5=rhat_t5)
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                            
                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess
for rank in range(8, 9):
    with SubProcess() as process: process.parent or learn_ranker(rank)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0       0.702       0.702           0           0           0           0        1.21
2                     0           0       0.701       0.701           0           0           0           0         1.9
4                     0           0       0.693       0.684         0.5           1           0           0         3.3
8                     0           0       0.672       0.646       0.571       0.667           1           1        5.73
16                    0           0       0.704       0.732       0.533         0.5           1           0        11.2
32                    0           0       0.653       0.598       0.621       0.714           1           1        21.3
64                    0           0        0.64       0.626       0.667       0.714       0.714         0.5        41.7
128                   0           0     

Bad pipe message: %s [b'\xeaC\t\x05MX\xcb\xc8\x17\x14mLc\x9e\xb0\xed\xe1\xf3 2\x08\x9e\t"\x02\x84\xfa*\xd9\xaf\x83\x07\x8b\nXl\x0fxaa\x83\x9d\x1f\x95\xfe\xff\x83j\xf8\xbf_\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 h\x06\xee\xa8\x12\x03*\x8c']
Bad pipe message: %s [b"\xe4o\xf2\x9b\n./Ef\x10\xb9H\xe4\x8b\x1a\xa7!\x1a\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\x

10931                 0           0       0.632       0.629       0.721       0.728       0.705       0.732    6.82e+03
12180             0.103           1       0.632           0       0.721           0       0.711       0.717    7.33e+03


In [None]:
# multiple training passes ...
def learn_ranker(rank):
    from RewardPredictor import PeftRewardPredictor
    from TaskLLM import PeftTaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = PeftConfig.from_pretrained('User_keq1_t5base_step1')
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm = PeftTaskLLM(peft_config, t5=t5, model_id='User_keq1_t5base_step1')

    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftRewardPredictor(rhat_peft_config, t5=rhat_t5)
    def reward_augment(inputs):
        import re
        return [ re.sub(r'Ref1: (.*)\nRef2: (.*)\nExtra:',
                        r'Ref1: \2\nRef2: \1\nExtra:',
                        z)
                 for z in inputs ]

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        max_iteration = 4
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration + 1 < max_iteration or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                            
                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts, augment=reward_augment)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards, augment=reward_augment) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration + 1 == max_iteration:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess
for rank in range(8, 9):
    with SubProcess() as process: process.parent or learn_ranker(rank)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0        0.71        0.71           1           1           0           0        1.28
2                     0           0       0.694       0.678         0.5           0           0           0        2.08
4                     0           0       0.683        0.66       0.667           1           1           1         3.2
8                     0           0       0.677       0.673       0.857           1           1           0        5.95
16                    0           0       0.671       0.664       0.769       0.667       0.667         0.5        10.7
32                    0           0       0.684       0.697       0.769       0.769       0.667       0.667          21
64                    0           0       0.665       0.645       0.706        0.64       0.692       0.714        40.8
128                   0           0     

In [17]:
# with data doubling and no reward augment ...
def learn_ranker(rank):
    from RewardPredictor import PeftRewardPredictor
    from TaskLLM import PeftTaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = PeftConfig.from_pretrained('User_keq1_t5base_step1')
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm = PeftTaskLLM(peft_config, t5=t5, model_id='User_keq1_t5base_step1')

    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftRewardPredictor(rhat_peft_config, t5=rhat_t5)

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} acc', f'{rank} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=rank).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                            
                            guesses = taskllm.predict(prompts, augment=train.swap_refs).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_rankeq{rank}')

from Fork import SubProcess
for rank in range(8, 9):
    with SubProcess() as process: process.parent or learn_ranker(rank)

n                  iter       since      8 loss       since       8 acc       since 8 acc (dev)       since      dt (s)
1                     0           0       0.702       0.702           0           0           0           0        1.19
2                     0           0       0.701       0.701           0           0           0           0        1.68
4                     0           0       0.693       0.684         0.5           1           0           0        2.69
8                     0           0       0.672       0.646       0.571       0.667           1           1        4.43
16                    0           0       0.704       0.732       0.533         0.5           1           0        8.29
32                    0           0       0.653       0.598       0.621       0.714           1           1        15.7
64                    0           0        0.64       0.626       0.667       0.714       0.714         0.5        30.6
128                   0           0     