In [14]:
import numpy as np
import seaborn as sns
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import pairwise_distance, relu, softmax
from sklearn.metrics import roc_auc_score
from copy import deepcopy

In [None]:
class TripletDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        anchor_row = self.df.iloc[idx]
        anchor = torch.tensor(anchor_row['web'], dtype=torch.float32)
        label = anchor_row['strong-class']
        positive_idx = ((self.df['strong-class'] == label) & (self.df.index != idx))
        positive_row = self.df.loc[positive_idx].sample(n=1)
        positive = torch.tensor(positive_row['web'].values[0], dtype=torch.float32)
        negative_idx = self.df['weak-class'] != label
        negative_row = self.df.loc[negative_idx].sample(n=1)
        negative = torch.tensor(negative_row['web'].values[0], dtype=torch.float32)
        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)
        return anchor, positive, negative

class TripletDataModule(pl.LightningDataModule):
    def __init__(self, df, train_size=0.8, val_size=0.1, test_size=0.1, batch_size=32, transform = None):
        super().__init__()
        self.df = df
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.batch_size = batch_size
        self.transform = transform

    def setup(self, stage=None):
        n = self.df['strong-class'].max()
        train_end = int(self.train_size * n)
        val_end = train_end + int(self.val_size * n)
        indices = np.random.permutation(n)
        train_indices = indices[:train_end]
        val_indices = indices[train_end:val_end]
        test_indices = indices[val_end:]
        self.train_dataset = TripletDataset(self.df[self.df['strong-class'].isin(train_indices)], transform=self.transform)
        self.val_dataset = TripletDataset(self.df[self.df['strong-class'].isin(val_indices)], transform=self.transform)
        self.test_dataset = TripletDataset(self.df[self.df['strong-class'].isin(test_indices)], transform=self.transform)


    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [None]:
dataloader = TripletDataModule(df)
dataloader.setup()

In [None]:
class SiameseNet(pl.LightningModule):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, learning_rate = 1e-3, loss_margin = 1):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size1)
        self.fc2 = torch.nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = torch.nn.Linear(hidden_size2, output_size)
        self.learning_rate = learning_rate
        self.loss = torch.nn.TripletMarginLoss(loss_margin)

    def forward(self, x):
        x = torch.flatten(x,start_dim=1)
        x = self.fc1(x)
        x = relu(x)
        x = self.fc2(x)
        x = relu(x)
        x = self.fc3(x)
        x = softmax(x)
        return x

    def shared_step(self, anchor, positive, negative):
        anchor_embedding = self.forward(anchor)
        positive_embedding = self.forward(positive)
        negative_embedding = self.forward(negative)
        return anchor_embedding, positive_embedding, negative_embedding

    def training_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_embedding, positive_embedding, negative_embedding = self.shared_step(anchor, positive, negative)
        loss = self.loss(anchor_embedding, positive_embedding, negative_embedding)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_embedding, positive_embedding, negative_embedding = self.shared_step(anchor, positive, negative)
        loss = self.loss(anchor_embedding, positive_embedding, negative_embedding)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def test_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_embedding, positive_embedding, negative_embedding = self.shared_step(anchor, positive, negative)
        distance_positive = pairwise_distance(anchor_embedding, positive_embedding)
        distance_negative = pairwise_distance(anchor_embedding, negative_embedding)
        return {
            'distance_positive': distance_positive,
            'distance_negative': distance_negative,
        }

    def test_epoch_end(self, outputs):
        distances_positive = torch.cat([o['distance_positive'] for o in outputs])
        distances_negative = torch.cat([o['distance_negative'] for o in outputs])
        y_pred = (distances_positive < distances_negative).to(torch.float32)
        accuracy = y_pred.float().mean()
        self.log('test_accuracy', accuracy, prog_bar=True)


    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
first_net = SiameseNet(6,200,200,50)
trainer = pl.Trainer(max_epochs=20)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(first_net,dataloader)

  rank_zero_deprecation(

  | Name | Type              | Params
-------------------------------------------
0 | fc1  | Linear            | 1.4 K 
1 | fc2  | Linear            | 40.2 K
2 | fc3  | Linear            | 10.1 K
3 | loss | TripletMarginLoss | 0     
-------------------------------------------
51.6 K    Trainable params
0         Non-trainable params
51.6 K    Total params
0.207     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  x = softmax(x)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
trainer.test(first_net,dataloader)

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  x = softmax(x)


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.644481897354126}
--------------------------------------------------------------------------------


[{'test_accuracy': 0.644481897354126}]