In [39]:
import geoopt
import torch

class Distance2PoincareHyperplanes(torch.nn.Module):
    n = 0

    def __init__(
        self,
        plane_shape: int,
        num_planes: int,
        signed=True,
        squared=False,
        *,
        ball,
        std=1.0,
    ):
        super().__init__()
        self.signed = signed
        self.squared = squared
        # Do not forget to save Manifold instance to the Module
        self.ball = ball
        self.plane_shape = geoopt.utils.size2shape(plane_shape)
        self.num_planes = num_planes

        # In a layer we create Manifold Parameters in the same way we do it for
        # regular pytorch Parameters, there is no difference. But geoopt optimizer
        # will recognize the manifold and adjust to it
        self.points = geoopt.ManifoldParameter(
            torch.empty(num_planes, plane_shape), manifold=self.ball
        )
        self.std = std
        # following best practives, a separate method to reset parameters
        self.reset_parameters()

    def forward(self, input):
        input_p = input.unsqueeze(-self.n - 1)
        points = self.points.permute(1, 0)
        points = points.view(points.shape + (1,) * self.n)

        distance = self.ball.dist2plane(
            x=input_p, p=points, a=points, signed=self.signed, dim=-self.n - 2
        )
        if self.squared and self.signed:
            sign = distance.sign()
            distance = distance ** 2 * sign
        elif self.squared:
            distance = distance ** 2
        return distance

    def extra_repr(self):
        return (
            "plane_shape={plane_shape}, "
            "num_planes={num_planes}, "
            .format(**self.__dict__)
        )

    @torch.no_grad()
    def reset_parameters(self):
        direction = torch.randn_like(self.points)
        direction /= direction.norm(dim=-1, keepdim=True)
        distance = torch.empty_like(self.points[..., 0]).normal_(std=self.std)
        self.points.set_(self.ball.expmap0(direction * distance.unsqueeze(-1)))

In [8]:
optim = geoopt.optim.RiemannianAdam(classifier.parameters(), lr=1e-3)

In [22]:
import pandas as pd

embeds = pd.read_csv(
    "/home/phil/phylosig/emp/data/128d_hyperbolic_mixture_embeddings.csv",
    index_col=0
)
embeds

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,118,119,120,121,122,123,124,125,126,127
1001.SKM3,0.000488,-0.000606,0.000227,-0.000091,-0.000911,0.000063,-0.001924,0.001461,0.000541,-0.001521,...,-0.002705,-0.001286,0.001113,-0.001782,0.001084,0.000337,0.001531,0.000818,-0.000121,0.000782
1001.SKD6,0.000544,-0.000694,0.000196,-0.000117,-0.001020,0.000085,-0.002015,0.001409,0.000571,-0.001572,...,-0.002741,-0.001680,0.001006,-0.001735,0.001209,0.000319,0.001423,0.000808,-0.000156,0.000952
1001.SKM1,0.000549,-0.000660,0.000233,-0.000021,-0.000951,0.000017,-0.001953,0.001447,0.000538,-0.001555,...,-0.002895,-0.001981,0.001002,-0.001747,0.001140,0.000348,0.001446,0.000816,-0.000140,0.000766
1001.SKM2,0.000527,-0.000719,0.000238,-0.000064,-0.001034,0.000006,-0.001905,0.001430,0.000593,-0.001566,...,-0.002740,-0.001334,0.001113,-0.001690,0.001219,0.000332,0.001395,0.000833,-0.000189,0.000795
1001.SKB3,0.000542,-0.000658,0.000167,-0.000081,-0.000906,-0.000105,-0.001987,0.001494,0.000549,-0.001506,...,-0.002676,-0.000774,0.001054,-0.001706,0.000930,0.000254,0.001312,0.000799,-0.000113,0.001010
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
990.KA3U.C.05,0.000549,0.000349,-0.000401,0.000644,-0.001443,0.001136,-0.000814,0.001491,0.001485,-0.001314,...,-0.000508,-0.007845,-0.001054,-0.001690,0.000838,0.001590,-0.000202,-0.000459,-0.000850,0.001860
990.KA3U.D.05,0.000489,-0.000466,0.000291,0.000246,-0.001183,-0.001305,-0.002921,0.002035,0.000431,-0.001303,...,-0.002709,0.000600,0.000098,-0.002060,0.002346,0.000084,0.001262,0.000394,-0.000241,0.000479
990.KA3U.C.12,0.000802,-0.000758,-0.000203,-0.000326,-0.001499,-0.001069,-0.002120,0.001840,0.000581,-0.001223,...,-0.006931,0.000853,0.001886,-0.000987,0.000207,-0.000273,0.000089,0.000389,0.000357,-0.000189
990.KA2F.B.11,0.000692,-0.001427,0.000083,-0.000969,-0.001901,-0.002096,-0.001985,0.001620,0.000590,-0.001327,...,-0.002420,0.000532,0.000715,-0.001973,0.002173,0.000739,0.001604,0.000770,-0.000821,-0.000219


In [34]:
emp_metadata = pd.read_table(
    "/home/phil/phylosig/emp/data/emp_qiime_mapping_release1_20170912.tsv",
    index_col=0
)
emp_metadata = emp_metadata.loc[embeds.index]
emp_metadata["host_subject_id"]

1001.SKM3             SKM3
1001.SKD6             SKD6
1001.SKM1             SKM1
1001.SKM2             SKM2
1001.SKB3             SKB3
                   ...    
990.KA3U.C.05    KA3U.C.05
990.KA3U.D.05    KA3U.D.05
990.KA3U.C.12    KA3U.C.12
990.KA2F.B.11    KA2F.B.11
990.KA3U.D.12    KA3U.D.12
Name: host_subject_id, Length: 27398, dtype: object

In [37]:
# Generate dataset 

# One-hot encode the labels
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

y = emp_metadata["env_material"]
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(y)
onehot_encoder = OneHotEncoder(sparse=False)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
# y = onehot_encoder.fit_transform(integer_encoded)
# y = torch.tensor(y, dtype=torch.long)
y = torch.Tensor(integer_encoded).long()

X = embeds.values
X = torch.tensor(X, dtype=torch.float)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

In [40]:
N_CLASSES = len(set(y_train.numpy().flatten()))

classifier = torch.nn.Sequential(
    Distance2PoincareHyperplanes(128, 10, ball=geoopt.PoincareBall()),
    torch.nn.Linear(10, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, N_CLASSES),
)
classifier

Sequential(
  (0): Distance2PoincareHyperplanes(
    plane_shape=(128,), num_planes=10, 
    (ball): PoincareBall manifold
  )
  (1): Linear(in_features=10, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=43, bias=True)
)

In [41]:
# Train!

from torch.nn.functional import cross_entropy

for epoch in range(100):
    optim.zero_grad()
    output = classifier(X_train)
    loss = cross_entropy(output, y_train.squeeze())
    loss.backward()
    optim.step()
    print(f"Epoch {epoch}: {loss.item()}")

Epoch 0: 3.9980099201202393
Epoch 1: 3.9980099201202393
Epoch 2: 3.9980099201202393
Epoch 3: 3.9980099201202393
Epoch 4: 3.9980099201202393
Epoch 5: 3.9980099201202393
Epoch 6: 3.9980099201202393
Epoch 7: 3.9980099201202393
Epoch 8: 3.9980099201202393
Epoch 9: 3.9980099201202393
Epoch 10: 3.9980099201202393
Epoch 11: 3.9980099201202393
Epoch 12: 3.9980099201202393
Epoch 13: 3.9980099201202393
Epoch 14: 3.9980099201202393
Epoch 15: 3.9980099201202393
Epoch 16: 3.9980099201202393
Epoch 17: 3.9980099201202393
Epoch 18: 3.9980099201202393
Epoch 19: 3.9980099201202393
Epoch 20: 3.9980099201202393
Epoch 21: 3.9980099201202393
Epoch 22: 3.9980099201202393
Epoch 23: 3.9980099201202393
Epoch 24: 3.9980099201202393
Epoch 25: 3.9980099201202393
Epoch 26: 3.9980099201202393
Epoch 27: 3.9980099201202393
Epoch 28: 3.9980099201202393
Epoch 29: 3.9980099201202393
Epoch 30: 3.9980099201202393
Epoch 31: 3.9980099201202393
Epoch 32: 3.9980099201202393
Epoch 33: 3.9980099201202393
Epoch 34: 3.998009920120