Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mo_gaal Tensorflow -> Pytorch #577

Merged
merged 19 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/mo_gaal_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@
# Predictions on the held-out test split.
y_test_pred = clf.predict(X_test)  # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test)  # outlier scores

# Outlier probabilities plus a per-sample confidence in that prediction.
probabilities, confidence = clf.predict_proba(X_test, return_confidence=True)

# evaluate and print the results
print("\nOn Training Data:")
evaluate_print(clf_name, y_train, y_train_scores)
print("\nOn Test Data:")
evaluate_print(clf_name, y_test, y_test_scores)

108 changes: 41 additions & 67 deletions pyod/models/gaal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,75 @@
Part of the codes are adapted from
https://github.com/leibinghe/GAAL-based-outlier-detection
"""
# Author: Winston Li <jk_zhengli@hotmail.com>
# License: BSD 2 clause

from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

import tensorflow


# TODO: create a base class for so_gaal and mo_gaal
def create_discriminator(latent_size, data_size):
    """Create the discriminator of the GAN for a given latent size.

    The network is a two-layer MLP: latent_size -> ceil(sqrt(data_size))
    with ReLU, then -> 1 with sigmoid (probability the sample is real).

    Parameters
    ----------
    latent_size : int
        The size of the latent space of the generator.

    data_size : int
        Size of the input data.

    Returns
    -------
    discriminator : torch.nn.Module
        A PyTorch model of the discriminator.
    """

    class Discriminator(nn.Module):
        """Two-layer MLP discriminator (ReLU hidden layer, sigmoid output)."""

        def __init__(self, latent_size, data_size):
            super().__init__()
            hidden_size = math.ceil(math.sqrt(data_size))
            self.layer1 = nn.Linear(latent_size, hidden_size)
            self.layer2 = nn.Linear(hidden_size, 1)
            # Mirror the original Keras VarianceScaling(fan_in, normal)
            # kernel initializer for both layers.
            nn.init.kaiming_normal_(self.layer1.weight, mode='fan_in',
                                    nonlinearity='relu')
            nn.init.kaiming_normal_(self.layer2.weight, mode='fan_in',
                                    nonlinearity='sigmoid')
            # Keras Dense defaults to a zero bias; PyTorch's nn.Linear
            # defaults to a uniform bias, so zero them explicitly to
            # reproduce the original initialization.
            nn.init.zeros_(self.layer1.bias)
            nn.init.zeros_(self.layer2.bias)

        def forward(self, x):
            x = F.relu(self.layer1(x))
            x = torch.sigmoid(self.layer2(x))
            return x

    return Discriminator(latent_size, data_size)


def create_generator(latent_size):
    """Create the generator of the GAN for a given latent size.

    The network is two identity-initialized linear layers of width
    ``latent_size`` with ReLU activations, so at initialization it is the
    map ``x -> relu(x)`` and training perturbs it from there.

    Parameters
    ----------
    latent_size : int
        The size of the latent space of the generator.

    Returns
    -------
    generator : torch.nn.Module
        A PyTorch model of the generator.
    """

    class Generator(nn.Module):
        """Two identity-initialized linear layers with ReLU activations."""

        def __init__(self, latent_size):
            super().__init__()
            self.layer1 = nn.Linear(latent_size, latent_size)
            self.layer2 = nn.Linear(latent_size, latent_size)
            # Mirror the original Keras Identity kernel initializer.
            nn.init.eye_(self.layer1.weight)
            nn.init.eye_(self.layer2.weight)
            # Keras Dense defaults to a zero bias; PyTorch's nn.Linear
            # defaults to a uniform bias, which would break the intended
            # identity mapping at initialization -- zero them explicitly.
            nn.init.zeros_(self.layer1.bias)
            nn.init.zeros_(self.layer2.bias)

        def forward(self, x):
            x = F.relu(self.layer1(x))
            x = F.relu(self.layer2(x))
            return x

    return Generator(latent_size)
Loading
Loading