Commit 6f425c1: not know

yifding committed Sep 15, 2020
2 parents d7acfdb + 1c54425
Showing 4 changed files with 17 additions and 23 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -28,7 +28,7 @@ $ conda install python=3.7.4
```
$ git clone https://github.com/yifding/hetseq.git
$ cd /path/to/hetseq
-$ pip install -r requirement.txt
+$ pip install -r requirements.txt
$ pip install --editable .
```
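The fix points pip at `requirements.txt`, the filename this commit's third diff confirms the repository actually ships. A quick hedged check that the editable install worked (the importable module name `hetseq` is an assumption, not shown in this diff):

```
# Hypothetical post-install sanity check.
import hetseq
print("hetseq imported from:", hetseq.__file__)
```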

@@ -72,7 +72,8 @@ this repository is MIT-licensed. It is created based on [fairseq](https://github

Please send us e-mail or leave comments on github if have any questions.

-Copyright (c) 2020 Yifan Ding, Weninger Lab

+Copyright (c) 2020 Yifan Ding and [Weninger Lab](https://www3.nd.edu/~tweninge/)



2 changes: 1 addition & 1 deletion docs/source/installing.rst
@@ -19,7 +19,7 @@ Git Clone and Install Packages
$ git clone https://github.com/yifding/hetseq.git
$ cd /path/to/hetseq
-$ pip install -r requirement.txt
+$ pip install -r requirements.txt
$ pip install --editable .
Download BERT Processed File
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,11 +1,12 @@
cython == 0.29.13
numpy == 1.17.0
h5py == 2.10.0
-torch == 1.2.0
+torch == 1.6.0
tqdm == 4.36.1
-boto3 == 1.9.244
chardet == 3.0.4
idna == 2.8
python-dateutil == 2.8.0
sphinx-rtd-theme == 0.5.0
sphinx == 3.2.1
+boto3 == 1.9.244
+torchvision == 0.7.0
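torch moves from 1.2.0 to 1.6.0, and torchvision 0.7.0 arrives alongside it (0.7.0 is the torchvision release paired with torch 1.6.0); boto3 is only reordered. A small sketch to confirm an environment matches the new pins:

```
# Check installed versions against the updated pins in requirements.txt.
import torch
import torchvision

assert torch.__version__.startswith("1.6"), torch.__version__
assert torchvision.__version__.startswith("0.7"), torchvision.__version__
print(torch.__version__, torchvision.__version__)
```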
28 changes: 10 additions & 18 deletions tasks.py
@@ -48,7 +48,7 @@ def dataset(self, split):
Args:
split (str): name of the split (e.g., train, valid, test)
Returns:
-a :class:`~fairseq.data.FairseqDataset` corresponding to *split*
+a :class:`~torch.utils.data.Dataset` corresponding to *split*
"""
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
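The return type in the docstring changes from fairseq's dataset class to plain `torch.utils.data.Dataset`, in line with stripping the fairseq dependency. A hedged sketch of the accessor pattern the surrounding lines imply (`self.datasets` and the `KeyError` appear in the diff; the rest is assumed):

```
from torch.utils.data import Dataset

class Task:
    def __init__(self, args):
        self.args = args
        self.datasets = {}  # maps a split name ('train', 'valid', ...) to a Dataset

    def dataset(self, split):
        """Return the torch.utils.data.Dataset loaded for *split*."""
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        return self.datasets[split]
```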
@@ -87,7 +87,7 @@ def get_batch_iterator(
epoch (int, optional): the epoch to start the iterator from
(default: 0).
Returns:
-~fairseq.iterators.EpochBatchIterator: a batched iterator over the
+~iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
# For default fairseq task, return same iterator across epochs
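The iterator class keeps its name but drops the `fairseq.` prefix, since hetseq bundles its own `iterators` module. A loose sketch of the caching behavior the comment above describes; the cache attribute and the construction helper are hypothetical, and the real `EpochBatchIterator` arguments are not shown in this diff:

```
def get_batch_iterator(self, dataset, epoch=0, **kwargs):
    # Build the iterator once, then return the same object on later calls.
    if getattr(self, "_cached_iterator", None) is None:
        # hypothetical helper standing in for iterators.EpochBatchIterator(...)
        self._cached_iterator = self._new_epoch_batch_iterator(dataset, epoch, **kwargs)
    return self._cached_iterator
```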
@@ -127,12 +127,12 @@ def get_batch_iterator(

def build_model(self, args):
"""
-Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
+Build the :class:`~torch.nn.Module` instance for this
task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
-a :class:`~fairseq.models.BaseFairseqModel` instance
+a :class:`~torch.nn.Module` instance
"""
raise NotImplementedError
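Since the contract now only requires a plain `torch.nn.Module`, a subclass can return any standard PyTorch model. A hypothetical example subclass (the name and architecture are illustrative only):

```
import torch.nn as nn

class MNISTTask(Task):  # hypothetical subclass name
    def build_model(self, args):
        # Any torch.nn.Module satisfies the new contract.
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
```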

@@ -142,10 +142,9 @@ def train_step(self, sample, model, optimizer, ignore_grad=False):
for the given *model* and *sample*.
Args:
sample (dict): the mini-batch. The format is defined by the
-:class:`~fairseq.data.FairseqDataset`.
-model (~fairseq.models.BaseFairseqModel): the model
-criterion (~fairseq.criterions.FairseqCriterion): the criterion
-optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
+:class:`~torch.utils.data.Dataset`.
+model (~torch.nn.Module): the model
+optimizer (~optim._Optimizer): the optimizer
ignore_grad (bool): multiply loss by 0 if this is set to True
Returns:
tuple:
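The visible part of the docstring ends here; the rest of `train_step` is collapsed. A hedged sketch of one conventional way to satisfy this contract (not hetseq's actual body, which this diff does not show):

```
def train_step(self, sample, model, optimizer, ignore_grad=False):
    # Sketch only: assumes the model computes and returns its own loss,
    # as the MNIST example at the bottom of this file does.
    model.train()
    loss = model(**sample)
    if ignore_grad:
        loss = loss * 0  # keeps the graph intact but contributes no gradient
    loss.backward()
    return loss  # the docstring promises a tuple; extra outputs are elided here
```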
@@ -187,12 +186,9 @@ def update_step(self, num_updates):
class LanguageModelingTask(Task):
"""
Train a language model, currently support BERT.
-Args:
-args: parsed from command line
-dictionary: the BPE dictionary for the input of the language model
Args:
args: parsed from command line
dictionary: the BPE dictionary for the input of the language model
"""

def __init__(self, args, dictionary):
@@ -215,7 +211,6 @@ def setup_task(cls, args, **kwargs):
"""
dictionary = cls.load_dictionary(cls, args.dict)

-#return cls(args, dictionary, output_dictionary, targets=targets)
return cls(args, dictionary)

def build_model(self, args):
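With the commented-out fairseq call gone, `setup_task` constructs the task from `args` and the loaded dictionary alone. A hypothetical usage sketch (`args.dict` follows the diff; the argument parsing itself is assumed):

```
# Hypothetical driver code: args comes from an argparse parser whose
# --dict option points at the BPE dictionary file.
task = LanguageModelingTask.setup_task(args)
model = task.build_model(args)
```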
@@ -323,8 +318,6 @@ def __init__(self):
self.fc2 = nn.Linear(128, 10)

def forward(self, x, target, eval=False):
-# print('shape', x.shape, target.shape)
-# print(target)

x = self.conv1(x)
x = F.relu(x)
@@ -339,5 +332,4 @@ def forward(self, x, target, eval=False):
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
loss = F.nll_loss(output, target)
-# return loss if not eval else output, loss
return loss
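The cleanup leaves `forward` returning only the NLL loss of the log-softmax outputs. A self-contained miniature of the same pattern (the conv layers and exact shapes of the full model are omitted; `fc2 = nn.Linear(128, 10)` is the only layer this diff actually shows):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    """Hypothetical miniature of the diff's MNIST-style network."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # assumed input size
        self.fc2 = nn.Linear(128, 10)       # matches the diff

    def forward(self, x, target):
        x = F.relu(self.fc1(x.flatten(1)))
        output = F.log_softmax(self.fc2(x), dim=1)
        return F.nll_loss(output, target)

loss = TinyNet()(torch.randn(4, 1, 28, 28), torch.tensor([3, 1, 4, 1]))
loss.backward()
```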
