Skip to content

Commit

Permalink
Add automatically download checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
phamquiluan committed Aug 5, 2020
1 parent d3f9a58 commit f206331
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ImageNet training in PyTorch - Residual Attention Network

[![version](https://img.shields.io/badge/version-v0.0.1-blue)](https://github.com/phamquiluan/ResidualAttentionNetwork)
[![version](https://img.shields.io/badge/version-v0.1.0-blue)](https://github.com/phamquiluan/ResidualAttentionNetwork)
[![phamquiluan/ResidualAttentionNetwork](https://circleci.com/gh/phamquiluan/ResidualAttentionNetwork.svg?style=shield&circle-token=f96e4e1a66e86406f9a01512c52e1185b731ab0e)](https://app.circleci.com/pipelines/github/phamquiluan/ResidualAttentionNetwork)
[![phamquiluan/ResidualAttentionNetwork](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/phamquiluan/ResidualAttentionNetwork)

Expand All @@ -11,7 +11,7 @@ This implements training of [Residual Attention Network](https://arxiv.org/abs/1
# Install

```bash
pip install 'git+ssh://git@github.com/phamquiluan/ResidualAttentionNetwork.git@v0.0.1'
pip install 'git+ssh://git@github.com/phamquiluan/ResidualAttentionNetwork.git@v0.1.0'
```

# Quickstart
Expand Down
11 changes: 7 additions & 4 deletions resattnet/resattnet56.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
import torch
import torch.nn as nn
from pytorchcv.model_provider import get_model as ptcv_get_model


def resattnet56(in_channels, num_classes, pretrained=True):
def resattnet56(in_channels=3, num_classes=1000, pretrained=True):
model = ptcv_get_model("resattnet56", pretrained=False)

if pretrained is True:
# load pretrained automatically
pass

state = torch.hub.load_state_dict_from_url(
"https://github.com/phamquiluan/ResidualAttentionNetwork/releases/download/v0.1.0/resattnet56.pth"
)
model.load_state_dict(state["state_dict"])
model.output = nn.Linear(2048, num_classes)
return model
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from setuptools import find_packages, setup

version = "0.0.1"
version = "0.1.0"
cwd = os.path.dirname(os.path.abspath(__file__))


Expand All @@ -29,6 +29,6 @@ def write_version_file():
"imgaug",
"tensorboard",
"sklearn",
"gputil"
"gputil",
],
)

0 comments on commit f206331

Please sign in to comment.