Skip to content

Commit 864d79e

Browse files
committed
Add repositories, authors and frontpage like search.
1 parent ff7afe7 commit 864d79e

File tree

6 files changed

+271
-6
lines changed

6 files changed

+271
-6
lines changed

paperswithcode/client.py

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from paperswithcode.models import (
1010
Paper,
1111
Papers,
12+
Repository,
1213
Repositories,
14+
PaperRepos,
15+
Author,
16+
Authors,
1317
Conference,
1418
Conferences,
1519
Proceeding,
@@ -69,7 +73,7 @@ def __parse(url: str) -> int:
6973
return 1
7074
else:
7175
q = parse.parse_qs(p.query)
72-
return q.get("page", [1])[0]
76+
return int(q.get("page", [1])[0])
7377

7478
@classmethod
7579
def __page(cls, result, page_model):
@@ -86,6 +90,34 @@ def __page(cls, result, page_model):
8690
results=result["results"],
8791
)
8892

93+
@handler
def search(
    self,
    q: Optional[str] = None,
    page: int = 1,
    items_per_page: int = 50,
) -> PaperRepos:
    """Run a search similar to the one offered on the frontpage.

    Args:
        q (str, optional): Text used to query the paper title and
            abstract. Left out of the request when it is None.
        page (int): Desired page. Default: 1.
        items_per_page (int): Desired number of items per page.
            Default: 50.

    Returns:
        PaperRepos: PaperRepos object holding the requested page.
    """
    query = self.__params(page, items_per_page)
    if q is not None:
        query["q"] = q
    # The timeout is handed over explicitly as None.
    response = self.http.get("/search/", params=query, timeout=None)
    return self.__page(response, PaperRepos)
120+
89121
@handler
90122
def paper_list(
91123
self,
@@ -243,6 +275,166 @@ def paper_result_list(
243275
Results,
244276
)
245277

278+
@handler
def repository_list(
    self,
    q: Optional[str] = None,
    owner: Optional[str] = None,
    name: Optional[str] = None,
    stars: Optional[int] = None,
    framework: Optional[str] = None,
    page: int = 1,
    items_per_page: int = 50,
) -> Repositories:
    """Return a paginated list of repositories.

    Only filters that are not None are sent with the request.

    Args:
        q (str, optional): Search all searchable fields.
        owner (str, optional): Filter repositories by owner.
        name (str, optional): Filter repositories by name.
        stars (int, optional): Filter repositories by minimum number of
            stars.
        framework (str, optional): Filter repositories by framework.
            Available values: tf, pytorch, mxnet, torch, caffe2, jax,
            paddle, mindspore.
        page (int): Desired page. Default: 1.
        items_per_page (int): Desired number of items per page.
            Default: 50.

    Returns:
        Repositories: Repositories object.
    """
    params = self.__params(page, items_per_page)

    if q is not None:
        params["q"] = q
    if owner is not None:
        params["owner"] = owner
    if name is not None:
        params["name"] = name
    if stars is not None:
        # Sent as a string, like the other query parameters.
        params["stars"] = str(stars)
    if framework is not None:
        params["framework"] = framework
    return self.__page(
        self.http.get("/repositories/", params=params),
        Repositories,
    )
323+
324+
@handler
def repository_owner_list(self, owner: str) -> Repositories:
    """List all repositories for a specific repo owner.

    Args:
        owner (str): Repository owner.

    Returns:
        Repositories: Repositories object.
    """
    # Trailing slash added so the URL has the same form as every other
    # endpoint used by this client.
    return self.__page(
        self.http.get(f"/repositories/{owner}/"),
        Repositories,
    )
338+
339+
@handler
def repository_get(self, owner: str, name: str) -> Repository:
    """Return a repository selected by its owner/name pair.

    Args:
        owner (str): Owner name.
        name (str): Repository name.

    Returns:
        Repository: Repository object.
    """
    payload = self.http.get(f"/repositories/{owner}/{name}/")
    return Repository(**payload)
351+
352+
@handler
def repository_paper_list(
    self, owner: str, name: str, page: int = 1, items_per_page: int = 50
) -> Papers:
    """List all papers connected to the repository.

    Args:
        owner (str): Owner name.
        name (str): Repository name.
        page (int): Desired page. Default: 1.
        items_per_page (int): Desired number of items per page.
            Default: 50.

    Returns:
        Papers: Papers object.
    """
    query = self.__params(page, items_per_page)
    response = self.http.get(
        f"/repositories/{owner}/{name}/papers/", params=query
    )
    return self.__page(response, Papers)
375+
376+
@handler
def author_list(
    self,
    q: Optional[str] = None,
    full_name: Optional[str] = None,
    page: int = 1,
    items_per_page: int = 50,
) -> Authors:
    """Return a paginated list of paper authors.

    Only filters that are not None are sent with the request.

    Args:
        q (str, optional): Search all searchable fields.
        full_name (str, optional): Filter authors by part of their full
            name.
        page (int): Desired page. Default: 1.
        items_per_page (int): Desired number of items per page.
            Default: 50.

    Returns:
        Authors: Authors object.
    """
    params = self.__params(page, items_per_page)

    if q is not None:
        params["q"] = q
    if full_name is not None:
        params["full_name"] = full_name
    return self.__page(self.http.get("/authors/", params=params), Authors)
404+
405+
@handler
def author_get(self, author_id: str) -> Author:
    """Return the author identified by the given id.

    Args:
        author_id (str): Author id.

    Returns:
        Author: Author object.
    """
    payload = self.http.get(f"/authors/{author_id}/")
    return Author(**payload)
416+
417+
@handler
def author_paper_list(
    self, author_id: str, page: int = 1, items_per_page: int = 50
) -> Papers:
    """List all papers connected to the author.

    Args:
        author_id (str): Author id.
        page (int): Desired page. Default: 1.
        items_per_page (int): Desired number of items per page.
            Default: 50.

    Returns:
        Papers: Papers object.
    """
    query = self.__params(page, items_per_page)
    response = self.http.get(f"/authors/{author_id}/papers/", params=query)
    return self.__page(response, Papers)
437+
246438
@handler
247439
def conference_list(
248440
self,

paperswithcode/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
"Papers",
55
"Repository",
66
"Repositories",
7+
"PaperRepo",
8+
"PaperRepos",
9+
"Author",
10+
"Authors",
711
"Conference",
812
"Conferences",
913
"Proceeding",
@@ -43,6 +47,8 @@
4347
from paperswithcode.models.page import Page
4448
from paperswithcode.models.paper import Paper, Papers
4549
from paperswithcode.models.repository import Repository, Repositories
50+
from paperswithcode.models.paper_repo import PaperRepo, PaperRepos
51+
from paperswithcode.models.author import Author, Authors
4652
from paperswithcode.models.conference import (
4753
Conference,
4854
Conferences,

paperswithcode/models/author.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import List
2+
3+
from tea_client.models import TeaClientModel
4+
5+
from paperswithcode.models.page import Page
6+
7+
8+
class Author(TeaClientModel):
    """Author object.

    Attributes:
        id (str): Author ID.
        full_name (str): Author full name.
    """

    id: str
    # NOTE(review): this field used to be documented as optional while the
    # annotation makes it required; the docstring now follows the
    # annotation. Confirm against the API whether it can actually be null.
    full_name: str
18+
19+
20+
class Authors(Page):
    """One page of a paginated author listing.

    Attributes:
        count (int): Number of elements matching the query.
        next_page (int, optional): Number of the next page.
        previous_page (int, optional): Number of the previous page.
        results (List[Author]): Authors contained in this page.
    """

    results: List[Author]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import List, Optional
2+
3+
from tea_client.models import TeaClientModel
4+
5+
from paperswithcode.models.page import Page
6+
from paperswithcode.models.paper import Paper
7+
from paperswithcode.models.repository import Repository
8+
9+
10+
class PaperRepo(TeaClientModel):
    """Pairing of a paper with a repository.

    Attributes:
        paper (Paper): Paper object.
        repository (Repository, optional): Repository object.
    """

    paper: Paper
    repository: Optional[Repository]
20+
21+
22+
class PaperRepos(Page):
    """One page of a paginated paper<->repo listing.

    Attributes:
        count (int): Number of elements matching the query.
        next_page (int, optional): Number of the next page.
        previous_page (int, optional): Number of the previous page.
        results (List[PaperRepo]): Paper<->repo pairs contained in this
            page.
    """

    results: List[PaperRepo]

paperswithcode/models/repository.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
from tea_client.models import TeaClientModel
44

@@ -10,18 +10,23 @@ class Repository(TeaClientModel):
1010
1111
Attributes:
1212
url (str): URL of the repository.
13-
is_official (bool): Is this an official implementation of the paper.
13+
owner (str): Repository owner.
14+
name (str): Repository name.
1415
description (str): Repository description.
1516
stars (int): Number of repository stars.
1617
framework (str): Implementation framework (TensorFlow, PyTorch, MXNet,
17-
Torch, Jax, Caffee2...)
18+
Torch, Jax, Caffe2...).
19+
is_official (bool): Is this an official implementation of the paper.
20+
Available only when listing repositories for a specific paper.
1821
"""
1922

2023
url: str
21-
is_official: bool
24+
owner: str
25+
name: str
2226
description: str
2327
stars: int
2428
framework: str
29+
is_official: Optional[bool]
2530

2631

2732
class Repositories(Page):

paperswithcode/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ def __repr__(self):
1616
)
1717

1818

19-
version = Version(0, 2, 2)
19+
version = Version(0, 3, 0)
2020
__version__ = str(version)

0 commit comments

Comments
 (0)