TypeError after inference #2

Closed

MFajcik opened this issue May 17, 2022 · 2 comments

Comments

@MFajcik

MFajcik commented May 17, 2022

Hi, when saving the inference results as a JSON file via hover_inference.py, the output dictionary contains a set. Sets are not JSON serializable, so the save fails.

python -m hover_inference --root ./experiments/ --datadir . --index wiki17.hover.2bit
Traceback (most recent call last):
  File "xxx/conda/envs/colbert-v0.4/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "xxx/.conda/envs/colbert-v0.4/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "yyy/baleen/Baleen/hover_inference.py", line 53, in <module>
    main(args)
  File "yyy/baleen/Baleen/hover_inference.py", line 43, in main
    f.write(ujson.dumps(outputs) + '\n')
TypeError: {3910663, 1373715, 833561, 2479648, 3921953, 3408419, 3188274, 1399859, 372789, 1117238, 3283510, 3342401, 2585678, 1428049, 4948563, 1399892, 4449365, 4216407, 4502103, 819287, 3598429, 5187684, 625781, 3042432, 1485442, 3487369, 4166284, 148110, 3713169, 1338005, 1951900, 936613, 437414, 556716, 2666162, 573620, 4666549, 638144, 4154562, 4315335, 4230859, 4788429, 2613967, 174801, 4054227, 3768532, 5224152, 4914913, 2469090, 460517, 4820205, 1360625, 4264185, 3064580, 424200, 4601613, 4707087, 2140434, 3422995, 3878677, 3583776, 2412329, 5212973, 3787053, 4286261, 2512694, 821559, 4174137, 3351359, 349002, 3896143, 3414369, 875881, 1557358, 3957103, 4061041, 3913073, 2986353, 959347, 803705, 4757370, 1752441, 2359693, 4729260, 1178030, 1897903, 5206962, 564149, 4238275, 4074960, 1900502, 4158425, 4635100, 4552679, 1106923, 3795442, 3049975, 2750972, 4602365, 1399295} is not JSON serializable

Every item in the dictionary to be saved looks like this:

0: ([(424200, 2), (4635100, 1), (4635100, 0)],
{3910663, 1373715, 833561, 2479648, 3921953, 3408419, 3188274, 1399859, 372789, 1117238, 3283510, 3342401, 2585678, 1428049, 4948563, 1399892, 4449365, 4216407, 4502103, 819287, 3598429, 5187684, 625781, 3042432, 1485442, 3487369, 4166284, 148110, 3713169, 1338005, 1951900, 936613, 437414, 556716, 2666162, 573620, 4666549, 638144, 4154562, 4315335, 4230859, 4788429, 2613967, 174801, 4054227, 3768532, 5224152, 4914913, 2469090, 460517, 4820205, 1360625, 4264185, 3064580, 424200, 4601613, 4707087, 2140434, 3422995, 3878677, 3583776, 2412329, 5212973, 3787053, 4286261, 2512694, 821559, 4174137, 3351359, 349002, 3896143, 3414369, 875881, 1557358, 3957103, 4061041, 3913073, 2986353, 959347, 803705, 4757370, 1752441, 2359693, 4729260, 1178030, 1897903, 5206962, 564149, 4238275, 4074960, 1900502, 4158425, 4635100, 4552679, 1106923, 3795442, 3049975, 2750972, 4602365, 1399295})
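
For reference, here is a minimal standalone reproduction of the failure, using a toy dictionary that mimics the (facts, pids) structure above (the values are made up); ujson, like the standard json module, rejects Python sets:

import ujson

# Toy version of one output item: a (facts, pids) pair where pids is a set
outputs = {0: ([(424200, 2), (4635100, 1)], {3910663, 1373715, 833561})}
ujson.dumps(outputs)  # raises TypeError: {...} is not JSON serializable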

This is quite annoying after spending a few hours running the actual retrieval inference :).
Cheers,
Martin

environment

name: colbert-v0.4
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - blas=2.114=mkl
  - blas-devel=3.9.0=14_linux64_mkl
  - bzip2=1.0.8=h7f98852_4
  - ca-certificates=2021.10.8=ha878542_0
  - cudatoolkit=11.1.1=h6406543_10
  - cupy=10.4.0=py37h52a254a_0
  - faiss=1.7.0=py37cuda111hcc9d9d6_8_cuda
  - faiss-gpu=1.7.0=h788eb59_8
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.10.4=h0708190_1
  - gmp=6.2.1=h58526e2_0
  - gnutls=3.6.13=h85f3911_1
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h7f98852_1001
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - libblas=3.9.0=14_linux64_mkl
  - libcblas=3.9.0=14_linux64_mkl
  - libfaiss=1.7.0=cuda111hf54f04a_8_cuda
  - libfaiss-avx2=1.7.0=cuda111h1234567_8_cuda
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=11.2.0=h1d223b6_16
  - libgfortran-ng=11.2.0=h69a702a_16
  - libgfortran5=11.2.0=h5c6108e_16
  - libiconv=1.16=h516909a_0
  - liblapack=3.9.0=14_linux64_mkl
  - liblapacke=3.9.0=14_linux64_mkl
  - libnsl=2.0.0=h7f98852_0
  - libpng=1.6.37=h21135ba_2
  - libstdcxx-ng=11.2.0=he4da1e4_16
  - libtiff=4.0.9=he6b73bb_1
  - libuv=1.43.0=h7f98852_0
  - libzlib=1.2.11=h166bdaf_1014
  - llvm-openmp=13.0.1=he0ac6c6_1
  - mkl=2022.0.1=h8d4b97c_803
  - mkl-devel=2022.0.1=ha770c72_804
  - mkl-include=2022.0.1=h8d4b97c_803
  - ncurses=6.3=h27087fc_1
  - nettle=3.6=he412f7d_0
  - ninja=1.10.2=h4bd325d_1
  - numpy=1.21.6=py37h976b520_0
  - olefile=0.46=pyh9f0ad1d_1
  - openh264=2.1.1=h780b84a_0
  - openssl=3.0.3=h166bdaf_0
  - pillow=5.4.1=py37h34e0f95_0
  - pip=21.0.1=pyhd8ed1ab_0
  - python=3.7.12=hf930737_100_cpython
  - python_abi=3.7=2_cp37m
  - pytorch=1.9.0=py3.7_cuda11.1_cudnn8.0.5_0
  - readline=8.1=h46c0cb4_0
  - setuptools=62.1.0=py37h89c1867_0
  - sqlite=3.38.2=h4ff8645_0
  - tbb=2021.5.0=h924138e_1
  - tk=8.6.12=h27826a3_0
  - torchaudio=0.9.0=py37
  - torchvision=0.10.0=py37_cu111
  - wheel=0.37.1=pyhd8ed1ab_0
  - xz=5.2.5=h516909a_1
  - zlib=1.2.11=h166bdaf_1014
  - pip:
    - anyio==3.5.0
    - argon2-cffi==21.3.0
    - argon2-cffi-bindings==21.2.0
    - attrs==21.4.0
    - babel==2.10.1
    - backcall==0.2.0
    - beautifulsoup4==4.11.1
    - bitarray==2.4.1
    - bleach==5.0.0
    - blis==0.7.7
    - catalogue==2.0.7
    - certifi==2021.10.8
    - cffi==1.15.0
    - charset-normalizer==2.0.12
    - click==8.0.4
    - cymem==2.0.6
    - debugpy==1.6.0
    - decorator==5.1.1
    - defusedxml==0.7.1
    - entrypoints==0.4
    - fastjsonschema==2.15.3
    - fastrlock==0.8
    - filelock==3.6.0
    - gitdb==4.0.9
    - gitpython==3.1.27
    - huggingface-hub==0.5.1
    - idna==3.3
    - importlib-metadata==4.11.3
    - importlib-resources==5.7.1
    - ipykernel==6.13.0
    - ipython==7.32.0
    - ipython-genutils==0.2.0
    - ipywidgets==7.7.0
    - jedi==0.18.1
    - jinja2==3.1.1
    - joblib==1.1.0
    - json5==0.9.6
    - jsonschema==4.4.0
    - jupyter==1.0.0
    - jupyter-client==7.3.0
    - jupyter-console==6.4.3
    - jupyter-core==4.10.0
    - jupyter-server==1.16.0
    - jupyterlab==3.3.4
    - jupyterlab-pygments==0.2.2
    - jupyterlab-server==2.13.0
    - jupyterlab-widgets==1.1.0
    - langcodes==3.3.0
    - markupsafe==2.1.1
    - matplotlib-inline==0.1.3
    - mistune==0.8.4
    - murmurhash==1.0.7
    - nbclassic==0.3.7
    - nbclient==0.6.0
    - nbconvert==6.5.0
    - nbformat==5.3.0
    - nest-asyncio==1.5.5
    - notebook==6.4.11
    - notebook-shim==0.1.0
    - packaging==21.3
    - pandocfilters==1.5.0
    - parso==0.8.3
    - pathy==0.6.1
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - preshed==3.0.6
    - prometheus-client==0.14.1
    - prompt-toolkit==3.0.29
    - psutil==5.9.0
    - ptyprocess==0.7.0
    - pycparser==2.21
    - pydantic==1.8.2
    - pygments==2.12.0
    - pyparsing==3.0.8
    - pyrsistent==0.18.1
    - python-dateutil==2.8.2
    - pytz==2022.1
    - pyyaml==6.0
    - pyzmq==22.3.0
    - qtconsole==5.3.0
    - qtpy==2.0.1
    - regex==2022.4.24
    - requests==2.27.1
    - sacremoses==0.0.49
    - scipy==1.7.3
    - send2trash==1.8.0
    - six==1.16.0
    - smart-open==5.2.1
    - smmap==5.0.0
    - sniffio==1.2.0
    - soupsieve==2.3.2.post1
    - spacy==3.2.4
    - spacy-legacy==3.0.9
    - spacy-loggers==1.0.2
    - srsly==2.4.3
    - terminado==0.13.3
    - thinc==8.0.15
    - tinycss2==1.1.1
    - tokenizers==0.10.3
    - tornado==6.1
    - tqdm==4.64.0
    - traitlets==5.1.1
    - transformers==4.10.0
    - typer==0.4.1
    - typing-extensions==3.10.0.2
    - ujson==5.2.0
    - urllib3==1.26.9
    - wasabi==0.9.1
    - wcwidth==0.2.5
    - webencodings==0.5.1
    - websocket-client==1.3.2
    - widgetsnbextension==3.6.0
    - zipp==3.8.0
prefix: xxx/.conda/envs/colbert-v0.4
@okhat
Collaborator

okhat commented May 18, 2022

Good catch! Can you cast the set to a list? That sounds like it'll fix this. If you make a pull request, I'll merge it.
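
For what it's worth, a minimal sketch of that cast, assuming outputs maps each query id to a (facts, pids) pair as in the example above (the variable names here are illustrative, not necessarily the actual ones in hover_inference.py):

# Cast the set of pids to a sorted list right before dumping;
# sorted() gives deterministic output, since set iteration order is arbitrary.
outputs = {qid: (facts, sorted(pids)) for qid, (facts, pids) in outputs.items()}
f.write(ujson.dumps(outputs) + '\n')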

@MFajcik
Author

MFajcik commented Aug 19, 2022

@okhat Wouldn't it be better to return N deduplicated lists (where N is the number of hops) from the ColBERT engine, so that the individual retrieval results keep their order?

I would submit a pull request for ColBERT, but I am not sure whether this would cause problems with some of your scripts.

Edit: code-wise, something like this:

from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser


class Baleen:
    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []
        stage1_preds = None
        context = None

        pids_bag = [[] for _ in range(num_hops)]  # one deduplicated, order-preserving pid list per hop

        for hop_idx in range(num_hops):
            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
            ranking_ = []

            facts_pids = {pid for pid, _ in facts}  # pids already selected as facts in earlier hops

            for pid, rank, score in ranking:
                # print(f'[{score}] \t\t {searcher.collection[pid]}')
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

                # Keep at most k pids per hop, skipping any pid already placed in another hop's bag.
                if len(pids_bag[hop_idx]) < k:
                    if all(pid not in pids_bag[hi] for hi in range(num_hops)):
                        pids_bag[hop_idx].append(pid)

            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        assert sum(len(pids_per_hop) for pids_per_hop in pids_bag) == depth  # edit: fixed assert

        return stage2_L3x, pids_bag, stage1_preds
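
A side benefit of returning per-hop lists is that the serialization issue disappears without any casting; a quick self-contained check with made-up pids:

import ujson

pids_bag = [[424200, 4635100], [3910663, 1373715]]  # per-hop lists, order preserved
print(ujson.dumps(pids_bag))  # lists of ints serialize fine: [[424200,4635100],[3910663,1373715]]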

okhat closed this as completed on Jul 10, 2023