added complex token -> patch transformation
timojl committed Sep 27, 2022
1 parent b626809 commit ea54753
Showing 2 changed files with 24 additions and 5 deletions.
Readme.md: 10 changes (7 additions, 3 deletions)
@@ -44,10 +44,14 @@ git clone https://github.com/juhongm999/hsnet.git

### Weights

The MIT license does not apply to these weights.

We provide two sets of model weights, for D=64 (4.1MB) and D=16 (1.1MB).
```
wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
unzip -d weights -j weights.zip
```

- [CLIPSeg-D64](https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth) (4.1MB, without CLIP weights)
- [CLIPSeg-D16](https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth) (1.1MB, without CLIP weights)
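
For reference, a minimal loading sketch, assuming PyTorch is installed and that the D=64 checkpoint matches `version='ViT-B/16'` and `reduce_dim=64` as in the repository's quickstart; `strict=False` is needed because the checkpoints do not include the CLIP backbone:

```python
import torch
from models.clipseg import CLIPDensePredT

# build the model and load the downloaded decoder weights
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()

# strict=False: the checkpoint contains only the CLIPSeg decoder, not the CLIP backbone
model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location='cpu'), strict=False)
```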

### Training and Evaluation

models/clipseg.py: 19 changes (17 additions, 2 deletions)
@@ -300,7 +300,7 @@ class CLIPDensePredT(CLIPDenseBase):
    def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
                 add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None, complex_trans_conv=False):

        super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
        # device = 'cpu'
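
The only signature change above is the new `complex_trans_conv` flag (default `False`). A hypothetical call enabling it (the version and reduce_dim values are illustrative, not taken from the commit):

```python
from models.clipseg import CLIPDensePredT

# enable the deeper token -> patch decoder added in this commit
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
```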
@@ -337,7 +337,22 @@ def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, r
            # explicitly define transposed conv kernel size
            trans_conv_ks = (trans_conv, trans_conv)

        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
        if not complex_trans_conv:
            self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
        else:
            assert trans_conv_ks[0] == trans_conv_ks[1]

            tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4)

            self.trans_conv = nn.Sequential(
                nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]),
                nn.ReLU(),
                nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]),
            )

        # self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        assert len(self.extract_layers) == depth

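To make the effect of the new decoder concrete, here is a standalone sketch (not from the repository) that reproduces the complex branch for ViT-B/16-style settings, where trans_conv_ks is (16, 16) and tp_kernels is therefore (4, 4); the 22x22 token grid of a 352x352 input is upsampled back to full resolution in two 4x steps:

```python
import torch
import torch.nn as nn

reduce_dim = 64                      # illustrative value; the README weights use D=64
trans_conv_ks = (16, 16)             # ViT-B/16 patch size
tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4)   # (4, 4)

# same layer stack as the complex_trans_conv branch above
decoder = nn.Sequential(
    nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1),   # mix neighbouring tokens, keeps 22x22
    nn.ReLU(),
    nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]),  # 22 -> 88
    nn.ReLU(),
    nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]),           # 88 -> 352
)

tokens = torch.randn(1, reduce_dim, 22, 22)   # token features reshaped to a spatial grid
print(decoder(tokens).shape)                  # torch.Size([1, 1, 352, 352])
```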
