Skip to content

Commit

Permalink
renamed
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiaskatch committed Feb 21, 2024
1 parent b00a32e commit af98340
Show file tree
Hide file tree
Showing 17 changed files with 17 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Other requirements:

## Usage
We provide 2 main modules:
- ### [gated_linear_rnn.py](gated_linear_rnn/gated_linear_rnn.py)
- ### [gated_linear_rnn.py](flax_gated_linear_rnn/gated_linear_rnn.py)
A causal time mixing sequence model which can be used as a drop-in replacement for causal multi-head-attention.
Usage:
```
Expand Down Expand Up @@ -69,7 +69,7 @@ We provide 2 main modules:
- **Tied Input & Forget gate** (`use_tied_gates=True`) Ties the input and forget gate through the relation `forget_gate = 1-input_gate`.


- ## [gated_linear_rnn_lm.py](gated_linear_rnn/language_models/gated_linear_rnn_lm.py)
- ## [gated_linear_rnn_lm.py](flax_gated_linear_rnn/language_models/gated_linear_rnn_lm.py)
A GatedLinearRNN-based language model.
```
import jax
Expand Down
3 changes: 3 additions & 0 deletions flax_gated_linear_rnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .gated_linear_rnn import GatedLinearRNN
from .language_models import GatedLinearRNNLM

File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flax import linen as nn
import jax.numpy as jnp
from gated_linear_rnn.base_models.channel_mixing import ChannelMixing
from flax_gated_linear_rnn.base_models.channel_mixing import ChannelMixing

class SequenceModel(nn.Module):
n_layer: int
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gated_linear_rnn.base_models.sequence_model import SequenceModel
from gated_linear_rnn.base_models.time_mixing import CausalTimeMixing
from flax_gated_linear_rnn.base_models.sequence_model import SequenceModel
from flax_gated_linear_rnn.base_models.time_mixing import CausalTimeMixing
from typing import Optional, Callable
from gated_linear_rnn.gated_linear_rnn import GatedLinearRNN
from flax_gated_linear_rnn.gated_linear_rnn import GatedLinearRNN
from flax import linen as nn

class GatedLinearRNNLM(SequenceModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Callable
from gated_linear_rnn.base_models.sequence_model import SequenceModel
from gated_linear_rnn.base_models.time_mixing import CausalTimeMixing
from gated_linear_rnn.attention import MultiHeadSelfAttention
from flax_gated_linear_rnn.base_models.sequence_model import SequenceModel
from flax_gated_linear_rnn.base_models.time_mixing import CausalTimeMixing
from flax_gated_linear_rnn.attention import MultiHeadSelfAttention


class TransformerLM(SequenceModel):
Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions gated_linear_rnn/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import ast
from datetime import datetime
import wandb
from gated_linear_rnn import GateLoopLM, GateLoopText2SpeechModel
from gated_linear_rnn.language_models import TransformerLM
from flax_gated_linear_rnn import GateLoopLM, GateLoopText2SpeechModel
from flax_gated_linear_rnn.language_models import TransformerLM



Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='flax-gated-linear-rnn',
version='1.0.0',
version='1.0.1',
author='Tobias Katsch',
author_email='tobias.katsch42@gmail.com',
packages=find_packages(),
Expand Down
2 changes: 1 addition & 1 deletion utils/speech_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torchaudio
import numpy as np
import os
from gated_linear_rnn import GateLoopLM, GateLoopText2SpeechModel
from flax_gated_linear_rnn import GateLoopLM, GateLoopText2SpeechModel


vocab = {
Expand Down
2 changes: 1 addition & 1 deletion utils/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from flax.training import train_state
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, List
from gated_linear_rnn.gated_linear_rnn import *
from flax_gated_linear_rnn.gated_linear_rnn import *


def get_home_directory():
Expand Down

0 comments on commit af98340

Please sign in to comment.