In [230]:
import json 
with open('/data3/hanseongkim/CBR/dataset/dev_score_all.json','r') as f:
    dev = json.load(f)
len(dev)

883

In [231]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
token = tokenizer(train[0]['question'] +  '</s>' + train[0]['program'])

In [4]:
tokenizer.decode(token['input_ids'])

'<s>what is the the interest expense in 2009?</s>divide(arg1, arg2), divide(arg1, #0), EOF</s>'

In [248]:
from copy import deepcopy
from tqdm import tqdm
import random
from torch.nn.utils.rnn import pad_sequence
def get_sample_with_program(data,tokenizer):
    output = []
    for i in tqdm(range(len(data)), desc= 'processing'):
        origin = tokenizer(data[i]['question'],max_length=256, padding=True, truncation=True, return_tensors='pt')
        candidates = tokenizer(data[i]['question'] + '</s>' + data[i]['program'],max_length=256, padding=True, truncation=True, return_tensors='pt')
        for key, val in origin.items():
            origin[key] = val.squeeze()
        for key, val in candidates.items():
            origin['cand_' + key] = val.squeeze()
        output.append(origin)
        # question으로 정렬 먼저 진행
        gold = sorted(data[i]['gold_index'],key = lambda x : x['question_score'],reverse=True)
        output[i]['gold_inds'] = deepcopy([item['index'] for item in gold])
        output[i]['origin_index'] = i
    return output

""" Data Loader """
class WithProgramDataLoader:
    def __init__(self, is_training, data, batch_size):
        self.data = data
        self.visited = [False] * len(self.data)
        self.batch_size = batch_size
        self.is_training = is_training
        self.data_size = len(self.data)
        self.count = 0
        self.normal_term = 256
        if self.is_training:
            self.num_batches = 0# int(self.data_size / batch_size) if self.data_size % batch_size == 0 else int(self.data_size / batch_size) + 1
            self.batch_sampled = self.preprocessing()
        else:
            self.num_batches = int(self.data_size / batch_size) if self.data_size % batch_size == 0 else int(self.data_size / batch_size) + 1


    def __iter__(self):
        return self

    def __next__(self):
        if self.is_training:
            self.count += 1
            if self.count < self.num_batches:
                current_batch = deepcopy(self.batch_sampled[self.count])
                random.shuffle(current_batch)
                output = self.gen_output(current_batch)
                self.count += 1
                return output
            else:
                raise StopIteration
        else:
            if self.count < self.num_batches:
                return self.eval_batch()
            else:
                raise StopIteration
            
        
    def __len__(self):
        return self.num_batches

    def reset(self):
        self.count = 0
        self.shuffle_all_data()

    def shuffle_all_data(self):
        self.visited = [False] * self.data_size
        return 

    def eval_batch(self):
        output = {
            "input_ids" : [], 
            "attention_mask" : [], 
            "cand_input_ids" : [], 
            "cand_attention_mask" : [], 
            'gold_inds': []
        }
        for i in range(self.count * self.batch_size , min((self.count + 1)* self.batch_size,self.data_size)):
            output['input_ids'].append(self.data[i]['input_ids'])
            output['attention_mask'].append(self.data[i]['attention_mask'])
            output['cand_input_ids'].append(self.data[i]['cand_input_ids'])
            output['cand_attention_mask'].append(self.data[i]['cand_attention_mask'])
            output['gold_inds'].append(self.data[i]['gold_inds'])
            self.visited[i] = True
        output['input_ids'] = pad_sequence(output['input_ids'],batch_first=True,padding_value = 1)
        output['attention_mask'] = pad_sequence(output['attention_mask'],batch_first=True,padding_value = 0)

        output['cand_input_ids'] = pad_sequence(output['cand_input_ids'],batch_first=True,padding_value = 1)
        output['cand_attention_mask'] = pad_sequence(output['cand_attention_mask'],batch_first=True,padding_value = 0)
        self.count += 1
        return output
    
    def preprocessing(self):
        sampled_inds = []
        # for _ in tqdm(range(self.num_batches),desc = 'data sampling'):
        batch_size = 0
        batch = []
        with tqdm(total= self.data_size) as pbar:
            while not all(self.visited):
                cur_batch = self.get_batching()
                if len(cur_batch) < 3:
                    break
                if len(batch) +len(cur_batch) < self.batch_size:
                    batch += cur_batch
                elif len(cur_batch) > self.batch_size:
                    for i in range(0,len(cur_batch), self.batch_size):
                        sampled_inds.append(cur_batch[i: i + self.batch_size])
                        batch_size += 1
                    pbar.update(len(cur_batch))
                elif len(batch) +len(cur_batch) >= self.batch_size:
                    for i in range(0,len(batch), self.batch_size):
                        sampled_inds.append(batch[i: i + self.batch_size])
                        batch_size += 1
                    pbar.update(len(batch))
                    batch = []
                    batch += cur_batch
                else:
                    if batch:
                        for i in range(0,len(batch), self.batch_size):
                            sampled_inds.append(batch[i: i + self.batch_size])
                            batch_size += 1
                        pbar.update(len(batch))
        self.num_batches = batch_size
        return sampled_inds
    
    def get_batching(self,trial = 0):
        batch_gold = []
        # while True:
            # sample = random.choice([idx for idx,item in enumerate(self.visited) if item == False]) # random에서 데이터 하나 뽑고
            # sample = [idx for idx,item in enumerate(self.visited) if item == False]
        org_cands = [self.data[idx]['origin_index'] for idx in range(len(self.data)) if not self.visited[idx]]
        sample = sorted(org_cands,key = lambda x: -len(self.data[x]['gold_inds']))[0] # max 에서  추출중 현재

        # if not self.visited[sample]: # 무조건 visited False 상황임 근데 ㅋㅋ
        self.visited[sample] = True
        batch = [sample]
        pos_gold = self.data[sample]['gold_inds']
        size = 1
        candidates = [self.data[i]['origin_index'] for i in org_cands if i not in pos_gold] # negative candidates
        candidates = sorted(candidates,key = lambda x: len(self.data[x]['gold_inds'])) # max len should be back for pop

        
        batch_gold.extend(pos_gold)
        while size < self.normal_term:
            # 여기에 buffer를 추가해 주자.
            if len(candidates) > 4: # buffer
                cand = candidates.pop(0) # max: (), min: (0) # middle: len(candidates)//2)
            else:
                if trial < 2:
                    trial += 1
                    for i in batch:
                        self.visited[i] = False
                    self.get_batching(trial = trial)
                break

            if not self.visited[cand]: # 방문한적이 없고 -> 학습한 적이 없고
                cand_pos = self.data[cand]['gold_inds']  # candidates 의 pos를 구해오고
                if not bool(set(batch_gold + batch) & set(cand_pos + [cand])): # 겹치는게 없다면 #해 gold랑 gold끼리만 비교하고 있구나
                    batch.append(cand)
                    self.visited[cand] = True
                    batch_gold.extend(cand_pos)
                    size += 1
                else:
                    continue
                # 최종 batch에 append
        for item in batch:
            if item in batch_gold:
                raise Exception(f"positive gold in batch {item},\n gold batch : {batch_gold}")
        return batch

    def gen_output(self,batch):        
        output = {
            "input_ids" : [], 
            "attention_mask" : [], 
            "pos_input_ids" : [], 
            "pos_attention_mask" : [], 
            'gold_inds': [],
            'origin_index': []
        }

        for i in batch: # list of index
            output['input_ids'].append(self.data[i]['input_ids'])
            output['attention_mask'].append(self.data[i]['attention_mask'])
            output['gold_inds'].append(self.data[i]['gold_inds'])
            output['origin_index'].append(self.data[i]['origin_index'])
            if self.data[i]['gold_inds']:
                pos = self.data[i]['gold_inds'][0]
                self.data[i]['gold_inds'] = self.data[i]['gold_inds'][1:] + [self.data[i]['gold_inds'][0]]
            else:
                pos = i # origin index
            output['pos_input_ids'].append(self.data[pos]['cand_input_ids'])
            output['pos_attention_mask'].append(self.data[pos]['cand_attention_mask'])
            
        output['input_ids'] = pad_sequence(output['input_ids'],batch_first=True,padding_value = 1)
        output['attention_mask'] = pad_sequence(output['attention_mask'],batch_first=True,padding_value = 0)
        output['pos_input_ids'] = pad_sequence(output['pos_input_ids'],batch_first=True,padding_value = 1)
        output['pos_attention_mask'] = pad_sequence(output['pos_attention_mask'],batch_first=True,padding_value = 0)
        
        return output
output = get_sample_with_program(dev,tokenizer)

processing:   0%|          | 0/883 [00:00<?, ?it/s]

processing: 100%|██████████| 883/883 [00:00<00:00, 3316.82it/s]


In [249]:
d_loader = WithProgramDataLoader(False, output,64) 
# 256, sample: max, search: min: 2665/
#  25%|██▌       | 1571/6251 [02:01<10:19,  7.55it/s]
# 2048 2160건 3:33
# 256, sample: max, search: min
#  25%|██▌       | 1571/6251 [02:01<10:19,  7.55it/s]
# 256, sample: max, search: middle
#   25%|██▍       | 1557/6251 [02:01<10:17,  7.60it/s]
# 128, sample: middle, search: middle
#  20%|█▉        | 1231/6251 [02:03<08:22,  9.99it/s]
# 128, sample: min, search: middle
#  17%|█▋        | 1077/6251 [02:02<09:50,  8.76it/s]
# 128, sample: max, search: middle
#  25%|██▍       | 1546/6251 [02:00<06:08, 12.78it/s]

In [250]:
for tt in tqdm(d_loader,total = len(d_loader)):
    tt

100%|██████████| 14/14 [00:00<00:00, 937.20it/s]


In [229]:
ss = d_loader.batch_sampled
# len(ss[4])0
length = []
for cl in ss:
    length.append(len(cl))
print(max(length), sum(length)//len(length),len(length),min(length),Counter(length).most_common())

64 52 37 1 [(55, 6), (54, 5), (60, 4), (64, 3), (63, 3), (59, 2), (45, 2), (61, 2), (1, 1), (10, 1), (42, 1), (38, 1), (37, 1), (62, 1), (47, 1), (58, 1), (53, 1), (56, 1)]


In [6]:
def find_clusters(output):
    clusters = []
    
    # 각 요소들 간의 공통 요소를 찾아 군집을 형성합니다.
    for item in tqdm(output):
        cur_gold_inds = item['gold_inds']
        cur_cluster = None
        
        # 현재 요소를 기존 군집에 추가할 수 있는지 검사합니다.
        for idx, cluster in enumerate(clusters):
            flag = False
            if check_overlap(cur_gold_inds, cluster):
                # 이미 다른 군집에 속한 경우 두 군집을 병합합니다.
                cur_gold_inds.extend(cluster)
                cur_gold_inds = list(set(cur_gold_inds))
                flag = True
            if flag:
                clusters.remove(cluster)
                
    
        clusters.append(cur_gold_inds)
    return clusters

# 겹치는지 확인하는 함수입니다.
def check_overlap(list1, list2):
    common_elements = set(list1) & set(list2)
    return bool(common_elements)

# 예시 output
# output = [{'gold_inds': [1, 2, 3]}, {'gold_inds': [2, 3, 4]}, {'gold_inds': [5, 6]}]

clusters = find_clusters(output)
print(len(clusters))


  0%|          | 0/6251 [00:00<?, ?it/s]

100%|██████████| 6251/6251 [00:12<00:00, 517.73it/s]

143





In [197]:
length = []
for cl in clusters:
    length.append(len(cl))
print(max(length), sum(length)//len(length),len(length),min(length),Counter(length).most_common())

2128 43 143 0 [(0, 52), (1, 32), (3, 12), (8, 5), (5, 4), (10, 3), (4, 3), (7, 3), (2, 2), (11, 2), (6, 1), (14, 1), (28, 1), (25, 1), (24, 1), (110, 1), (60, 1), (46, 1), (44, 1), (15, 1), (159, 1), (251, 1), (142, 1), (50, 1), (90, 1), (12, 1), (45, 1), (108, 1), (47, 1), (49, 1), (133, 1), (1327, 1), (185, 1), (884, 1), (2128, 1)]


In [44]:
clusters[0]

[[1, 2, 3], [2, 3, 4]]

In [35]:
cnt =0
for item in output:
    if item['origin_index'] in item['gold_inds']:
        cnt += 1
cnt, len(output)

(0, 6251)

In [172]:
visited = [False] * len(output)
org_cands = [output[idx]['origin_index'] for idx in range(len(output)) if not visited[idx]]
sample = sorted(org_cands,key = lambda x: -len(output[x]['gold_inds']))[len(output)//2] # max 에서  추출중 현재

# if not visited[sample]: # 무조건 visited False 상황임 근데 ㅋㅋ
visited[sample] = True
batch = [sample]
pos_gold = output[sample]['gold_inds']
size = 1
candidates = list(set(org_cands) - set(pos_gold + [sample]))
# [i for i in org_cands if not visited[i] and i not in pos_gold + [sample]] # negative candidates
candidates = sorted(candidates,key = lambda x: len(output[x]['gold_inds'])) # max len should be back for pop

In [173]:
len(pos_gold), len(candidates),len(output), len(org_cands),sample

(1325, 4925, 6251, 6251, 5171)

In [105]:
candidates[0]

1

In [121]:
len(set(pos_gold))
cand = 823
set(pos_gold + [2]) & set(output[cand]['gold_inds'] + [cand])

set()

In [174]:
batch_gold = []
batch_gold.extend(pos_gold)
while candidates:
    # 여기에 buffer를 추가해 주자.
    if len(candidates): # buffer
        cand = candidates.pop()
        # print(len(output[cand]['gold_inds']))
    else:
        print(len(batch))
        break

    if not visited[cand]: # 방문한적이 없고 -> 학습한 적이 없고
        cand_pos = output[cand]['gold_inds']  # candidates 의 pos를 구해오고
        # print('set(batch_gold) & set(cand_pos) & set(batch + [cand])', set(batch_gold) & set(cand_pos) & set(batch + [cand]), bool(set(batch_gold) & set(cand_pos) & set(batch + [cand])))
        if not bool(set(batch_gold + batch) & set(cand_pos + [cand])): # 겹치는게 없다면 #해 gold랑 gold끼리만 비교하고 있구나
            batch.append(cand)
            visited[cand] = True
            batch_gold.extend(cand_pos)
            batch_gold = list(set(batch_gold))
            size += 1
        else:
            continue
len(batch), len(batch_gold), size

(126, 6109, 126)

In [124]:
for item in batch:
    if item in batch_gold:
        print('item', item)
        break

In [21]:
len(candidates), max([len(item['gold_inds']) for item in output])

(4125, 2126)

In [215]:
ss = d_loader.batch_sampled
for cl in ss:
    length.append(len(cl))
print(max(length), sum(length)//len(length),len(length),min(length),Counter(length).most_common())

2128 1 15457 0 [(0, 13526), (1, 215), (4, 201), (2, 159), (3, 152), (5, 137), (6, 124), (7, 100), (8, 86), (11, 78), (64, 70), (10, 69), (12, 56), (9, 50), (14, 46), (13, 45), (18, 30), (15, 29), (16, 25), (17, 24), (21, 19), (19, 16), (20, 12), (24, 11), (23, 11), (22, 11), (28, 9), (25, 9), (29, 8), (26, 8), (33, 7), (27, 7), (32, 6), (31, 6), (30, 6), (38, 5), (37, 5), (34, 5), (45, 4), (36, 4), (35, 4), (46, 3), (44, 3), (50, 3), (47, 3), (49, 3), (42, 3), (41, 3), (40, 3), (39, 3), (60, 2), (59, 2), (55, 2), (54, 2), (53, 2), (52, 2), (51, 2), (48, 2), (43, 2), (110, 1), (159, 1), (251, 1), (142, 1), (90, 1), (108, 1), (133, 1), (1327, 1), (185, 1), (884, 1), (2128, 1), (63, 1), (62, 1), (61, 1), (58, 1), (57, 1), (56, 1)]


In [212]:
ss

[[328,
  343,
  411,
  588,
  607,
  889,
  1045,
  1157,
  1198,
  1202,
  1238,
  1486,
  1519,
  1783,
  1848,
  2142,
  2263,
  2325,
  2331,
  2413,
  2491,
  2594,
  2725,
  2741,
  2847,
  2897,
  3031,
  3040,
  3136,
  3333,
  3575,
  3616,
  3757,
  3865,
  3868,
  4185,
  4234,
  4397,
  4496,
  4500,
  4795,
  5031,
  5309,
  5370,
  5583,
  5625,
  5945,
  6182,
  6186,
  106,
  143,
  179,
  604,
  1326,
  1327,
  1394,
  1520,
  1800,
  2014,
  2418,
  2446,
  2514,
  2809,
  3005],
 [343,
  411,
  588,
  607,
  889,
  1045,
  1157,
  1198,
  1202,
  1238,
  1486,
  1519,
  1783,
  1848,
  2142,
  2263,
  2325,
  2331,
  2413,
  2491,
  2594,
  2725,
  2741,
  2847,
  2897,
  3031,
  3040,
  3136,
  3333,
  3575,
  3616,
  3757,
  3865,
  3868,
  4185,
  4234,
  4397,
  4496,
  4500,
  4795,
  5031,
  5309,
  5370,
  5583,
  5625,
  5945,
  6182,
  6186,
  106,
  143,
  179,
  604,
  1326,
  1327,
  1394,
  1520,
  1800,
  2014,
  2418,
  2446,
  2514,
  2809,
  3005,
  

In [209]:
from collections import Counter
cnt = []
d_loader.reset()
tt = 0
for batch in d_loader:
    b_num = batch['input_ids'].shape[0]
    cnt.append(b_num)
    if b_num == 32:
        tt += 1
print(sum(cnt)// len(cnt), max(cnt), min(cnt),tt, len(cnt),sum(cnt))
Counter(cnt).most_common()



RuntimeError: received an empty list of sequences

In [16]:
batch = next(d_loader)
batch['input_ids'].shape

torch.Size([2, 29])

In [None]:
def check_positive(batch,ep):
    cnt = 0
    res = []
    # print('batch len:', len(batch['input_ids']))
    for i in range(batch['input_ids'].shape[0]):
        sample_input_ids = batch['input_ids'][i]
        sample_input_ids = tokenizer.decode(sample_input_ids)
        positive = batch['pos_input_ids'][i]
        positive_text = tokenizer.decode(positive)
        if train[batch['origin_index'][i]]['gold_index'] == []:
            continue
        org_positive = sorted(train[batch['origin_index'][i]]['gold_index'],key=lambda x : -x['question_score'])[min(ep,len(batch['gold_inds'][i])-1)]['question']
        if org_positive.replace(' ','') in positive_text.replace('<s>','').replace('<pad>','').replace(' ','') or sample_input_ids in positive_text:
            cnt +=1
        else: # 없다면
            res.append((sample_input_ids,positive_text,org_positive,str(batch['origin_index'][i])))
    return cnt ,len(res)
batch = next(d_loader)
print(tokenizer.decode(batch['input_ids'][1]))
print(tokenizer.decode(batch['pos_input_ids'][1]))
print(sorted(train[batch['origin_index'][1]]['gold_index'],key=lambda x : -x['question_score'])[0]['question'])
print(batch['origin_index'][1])
print(train[batch['gold_inds'][0][1]]['question'])

<s>by how many percentage points did the health care cost trend rate for next year increase in 2017?</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<s>what was the difference in the initial health care trend rate and the ultimate trend rate in 2017?</s>subtract(arg1, arg2), EOF</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
what was the difference in the initial health care trend rate and the ultimate trend rate in 2017?
823
what was the operating expenses in 2007 in billions


In [12]:
# check length
output[0]['gold_inds']

[1291,
 5032,
 4346,
 1113,
 538,
 1589,
 496,
 795,
 3488,
 2139,
 5,
 3389,
 1624,
 3393,
 1102,
 4416,
 511,
 4353,
 2304,
 812,
 3036,
 4608,
 4552,
 4266]

In [14]:
from tqdm import tqdm
# random sample middle search 
random.seed(42)
d_loader = WithProgramDataLoader(True, output,64)
for i in range(1):
    d_loader.reset()
    left = 0
    for batch in tqdm(d_loader):
        cnt,res = check_positive(batch,i)
        left += res
    print(f'EPOCH : {i} left number {left}')

  0%|          | 0/98 [00:00<?, ?it/s]

100%|██████████| 98/98 [00:51<00:00,  1.89it/s]

EPOCH : 0 left number 0





In [15]:
from tqdm import tqdm
random.seed(42)
d_loader = WithProgramDataLoader(True, output,1)
for i in range(10):
    d_loader.reset()
    left = 0
    for batch in tqdm(d_loader):
        cnt,res = check_positive(batch,i)
        left += res
    print(f'EPOCH : {i} left number {left}')

# 대충 3분이었음.

  0%|          | 0/6251 [00:00<?, ?it/s]

100%|██████████| 6251/6251 [03:18<00:00, 31.53it/s]  


EPOCH : 0 left number 1676


  4%|▍         | 252/6251 [00:19<04:59, 20.03it/s]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fc8700cccd0>>
Traceback (most recent call last):
  File "/home/hanseongkim/anaconda3/envs/finqa/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
  7%|▋         | 421/6251 [00:32<05:27, 17.83it/s]

In [None]:
def get_split_batch(data,batch_size):
    visited = [False] * len(data)
    bucket = []

    for i in range(int(len(data)/batch_size) + 1): # simple get batch numbert using 
    while True:
        if all(visited):
            break
        cur_pos = []
        cands = [[idx,len(data[idx]['gold_inds'])] for idx in range(len(data)) if not visited[idx] and idx not in cur_pos]
        # pos_len = [len(data[i]['gold_inds']) for i in range(len(data)) if not visited[i]]
        cands = sorted(cands, key = lambda x : -x[1])
        max_idx, min_idx = cands[0][0], cands[-1][0]
        if len(set(data[max_idx]['gold_inds']) & set(data[min_idx]['gold_inds']) & set(cur_pos)) == 0:
            visited[max_idx] = True
            visited[min_idx] = True
            cur_pos.extend(data[max_idx]['gold_inds'])
            cur_pos.extend(data[min_idx]['gold_inds'])



In [112]:
from tqdm import tqdm
random.seed(42)
d_loader = WithProgramDataLoader(True, output,4)
for i in range(10):
    d_loader.reset()
    left = 0
    for batch in tqdm(d_loader):
        cnt,res = check_positive(batch,i)
        left += res
    print(f'EPOCH : {i} left number {left}')

  0%|          | 0/1563 [00:00<?, ?it/s]

 50%|█████     | 788/1563 [07:55<07:47,  1.66it/s]  


KeyboardInterrupt: 

In [None]:
res