In [1]:
import string

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import nltk
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


## Set Seed

In [2]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x30c93d250>

## Read Data

In [3]:
df = pd.read_csv(
    "{YOUR_DATA_PATH}/sentence_level_data.csv",
    index_col=[0],
    storage_options={
        "key": "REDACTED",
        "secret": "REDACTED",
    }
)

## Prep Data

In [4]:
def clean_text(s: str) -> str:
    """Clean the text.

    :param s: (str)
    :return: str
    """
    return s.lower().translate(s.maketrans("", "", string.punctuation))

df["cleaned_setence"] = df["sentence"].apply(lambda x: clean_text(x))

In [5]:
nltk.download("wordnet")

lemmer = WordNetLemmatizer()

[nltk_data] Downloading package wordnet to /Users/alee/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [6]:
def lemmatize_text(s: str, lemmer: WordNetLemmatizer) -> str:
    """Lemmatize the text.

    :param s: (str)
    :param stemmer: (PorterStemmer)
    :return: (str)
    """
    return " ".join([lemmer.lemmatize(word) for word in s.split()])

In [7]:
df["lemmatized_text"] = df["cleaned_setence"].apply(lambda x: lemmatize_text(x, lemmer))

In [8]:
tfidf = TfidfVectorizer()
x_tfidf = tfidf.fit_transform(df["lemmatized_text"])

In [9]:
class ChatGPTDataset(Dataset):
    def __init__(self, x_tfidf: list, y: int) -> None:
        self.x_tfidf = x_tfidf
        self.y = y
        
    
    def __len__(self) -> int:
        return len(self.x_tfidf)
    
    
    def __getitem__(self, index: int) -> tuple:
        return self.x_tfidf[index], self.y[index]

In [10]:
chatgpt_dataset = ChatGPTDataset(x_tfidf.toarray(), y=df["class"].tolist())
train_indices, test_indices = train_test_split(
    list(range(0, len(chatgpt_dataset))), test_size=0.2, random_state=RANDOM_SEED
)

In [11]:
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

In [12]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    chatgpt_dataset,
    batch_size=batch_size,
    sampler=train_sampler
)
validation_loader = torch.utils.data.DataLoader(
    chatgpt_dataset,
    batch_size=batch_size,
    sampler=test_sampler
)

## Model

In [40]:
class DenseNetBayes(nn.Module):
    def __init__(self, input_shape):
        super().__init__()
        self.fc1 = nn.Linear(input_shape, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 128)
        self.pred = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x.to(torch.float)))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.log_softmax(self.pred(x).squeeze())
        
        return x

In [41]:
def modeller(x, y):
    fc1w_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.fc1.weight),
        scale=torch.ones_like(model.fc1.weight)
    )
    fc1b_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.fc1.bias),
        scale=torch.ones_like(model.fc1.bias)
    )

    fc2w_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.fc2.weight),
        scale=torch.ones_like(model.fc2.weight)
    )
    fc2b_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.fc2.bias),
        scale=torch.ones_like(model.fc2.bias)
    )

    fc3w_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.fc3.weight),
        scale=torch.ones_like(model.fc3.weight)
    )
    fc3b_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.fc3.bias),
        scale=torch.ones_like(model.fc3.bias)
    )
    
    predw_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.pred.weight),
        scale=torch.ones_like(model.pred.weight)
    )
    predb_prior = pyro.distributions.Normal(
        loc=torch.zeros_like(model.pred.bias),
        scale=torch.ones_like(model.pred.bias)
    )
    
    priors = {
        'fc1.weight': fc1w_prior,
        'fc1.bias': fc1b_prior,
        'fc2.weight': fc2w_prior,
        'fc2.bias': fc2b_prior,
        'fc3.weight': fc3w_prior,
        'fc3.bias': fc3b_prior,
        'pred.weight': predw_prior,
        'pred.bias': predb_prior
    }

    lifted_module = pyro.random_module("module", model, priors)
    lifted_reg_model = lifted_module()

    lhat = nn.LogSoftmax(dim=1)(lifted_reg_model(x))

    pyro.sample("obs", pyro.distributions.torch.Categorical(logits=lhat), obs=y)

In [42]:
def guide(x, y):
    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(model.fc1.weight)
    fc1w_sigma = torch.randn_like(model.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = nn.Softplus()(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = pyro.distributions.Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # First layer bias distribution priors
    fc1b_mu = torch.randn_like(model.fc1.bias)
    fc1b_sigma = torch.randn_like(model.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = nn.Softplus()(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = pyro.distributions.Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    # Second layer
    fc2w_mu = torch.randn_like(model.fc2.weight)
    fc2w_sigma = torch.randn_like(model.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = nn.Softplus()(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_prior = pyro.distributions.Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
    # Second layer bias distribution priors
    fc2b_mu = torch.randn_like(model.fc2.bias)
    fc2b_sigma = torch.randn_like(model.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = nn.Softplus()(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = pyro.distributions.Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)
    # Third layer
    fc3w_mu = torch.randn_like(model.fc3.weight)
    fc3w_sigma = torch.randn_like(model.fc3.weight)
    fc3w_mu_param = pyro.param("fc3w_mu", fc3w_mu)
    fc3w_sigma_param = nn.Softplus()(pyro.param("fc3w_sigma", fc3w_sigma))
    fc3w_prior = pyro.distributions.Normal(loc=fc3w_mu_param, scale=fc3w_sigma_param)
    # Second layer bias distribution priors
    fc3b_mu = torch.randn_like(model.fc3.bias)
    fc3b_sigma = torch.randn_like(model.fc3.bias)
    fc3b_mu_param = pyro.param("fc3b_mu", fc3b_mu)
    fc3b_sigma_param = nn.Softplus()(pyro.param("fc3b_sigma", fc3b_sigma))
    fc3b_prior = pyro.distributions.Normal(loc=fc3b_mu_param, scale=fc3b_sigma_param)
    # Output layer weight distribution priors
    predw_mu = torch.randn_like(model.pred.weight)
    predw_sigma = torch.randn_like(model.pred.weight)
    predw_mu_param = pyro.param("outw_mu", predw_mu)
    predw_sigma_param = nn.Softplus()(pyro.param("outw_sigma", predw_sigma))
    predw_prior = pyro.distributions.Normal(loc=predw_mu_param, scale=predw_sigma_param).independent(1)
    # Output layer bias distribution priors
    predb_mu = torch.randn_like(model.pred.bias)
    predb_sigma = torch.randn_like(model.pred.bias)
    predb_mu_param = pyro.param("predb_mu", predb_mu)
    predb_sigma_param = nn.Softplus()(pyro.param("predb_sigma", predb_sigma))
    predb_prior = pyro.distributions.Normal(loc=predb_mu_param, scale=predb_sigma_param)
    priors = {
        'fc1.weight': fc1w_prior,
        'fc1.bias': fc1b_prior,
        'fc2.weight': fc2w_prior,
        'fc2.bias': fc2b_prior,
        'fc3.weight': fc3w_prior,
        'fc3.bias': fc3b_prior,
        'pred.weight': predw_prior,
        'pred.bias': predb_prior
    }
    
    lifted_module = pyro.random_module("module", model, priors)
    
    return lifted_module()

In [43]:
model = DenseNetBayes(input_shape=chatgpt_dataset.x_tfidf.shape[1])
adam_args = {"lr": 0.005}
optimizer = Adam(adam_args)
elbo = Trace_ELBO()

In [44]:
svi = SVI(modeller, guide, optimizer, loss=elbo)

In [45]:
num_iterations = 20
loss = 0

pyro.clear_param_store()

for j in tqdm(range(num_iterations)):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        loss += svi.step(data[0], data[1])
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    
    print("Epoch ", j, " Loss ", total_epoch_loss_train)

  x = F.log_softmax(self.pred(x).squeeze())
  5%|██████                                                                                                                   | 1/20 [01:36<30:27, 96.19s/it]

Epoch  0  Loss  166292.18103979586


 10%|████████████                                                                                                             | 2/20 [03:12<28:53, 96.33s/it]

Epoch  1  Loss  83316.8292290688


 15%|██████████████████▏                                                                                                      | 3/20 [04:50<27:30, 97.11s/it]

Epoch  2  Loss  50009.50883185845


 20%|████████████████████████▏                                                                                                | 4/20 [06:30<26:11, 98.24s/it]

Epoch  3  Loss  33838.90256885314


 25%|██████████████████████████████▎                                                                                          | 5/20 [08:08<24:33, 98.22s/it]

Epoch  4  Loss  24528.832419462615


 30%|████████████████████████████████████▎                                                                                    | 6/20 [09:46<22:51, 97.94s/it]

Epoch  5  Loss  18496.96980904455


 35%|██████████████████████████████████████████▎                                                                              | 7/20 [11:22<21:08, 97.55s/it]

Epoch  6  Loss  14249.23374544545


 40%|████████████████████████████████████████████████▍                                                                        | 8/20 [12:59<19:28, 97.36s/it]

Epoch  7  Loss  11236.658984638718


 45%|██████████████████████████████████████████████████████▍                                                                  | 9/20 [14:36<17:49, 97.19s/it]

Epoch  8  Loss  8986.921254552528


 50%|████████████████████████████████████████████████████████████                                                            | 10/20 [16:13<16:10, 97.08s/it]

Epoch  9  Loss  7286.71979393276


 55%|██████████████████████████████████████████████████████████████████                                                      | 11/20 [17:51<14:35, 97.24s/it]

Epoch  10  Loss  6027.378449516034


 60%|████████████████████████████████████████████████████████████████████████                                                | 12/20 [19:27<12:55, 96.93s/it]

Epoch  11  Loss  5050.498555700179


 65%|██████████████████████████████████████████████████████████████████████████████                                          | 13/20 [20:59<11:08, 95.48s/it]

Epoch  12  Loss  4277.3238145544765


 70%|████████████████████████████████████████████████████████████████████████████████████                                    | 14/20 [22:35<09:33, 95.67s/it]

Epoch  13  Loss  3668.331450677142


 75%|██████████████████████████████████████████████████████████████████████████████████████████                              | 15/20 [24:11<07:58, 95.79s/it]

Epoch  14  Loss  3189.620796243723


 80%|████████████████████████████████████████████████████████████████████████████████████████████████                        | 16/20 [25:48<06:23, 95.98s/it]

Epoch  15  Loss  2789.40002029868


 85%|██████████████████████████████████████████████████████████████████████████████████████████████████████                  | 17/20 [27:23<04:47, 95.90s/it]

Epoch  16  Loss  2466.610961653227


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████            | 18/20 [28:59<03:11, 95.85s/it]

Epoch  17  Loss  2204.1753551807246


 95%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████      | 19/20 [30:35<01:35, 95.86s/it]

Epoch  18  Loss  1986.2088803329823


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [32:11<00:00, 96.56s/it]

Epoch  19  Loss  1795.8730080969842





In [46]:
def predict(x):
    sampled_models = [guide(None, None) for _ in range(2)]
    yhats = [model(x).data for model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return torch.argmax(mean, dim=1)


print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(validation_loader):
    text, labels = data
    predicted = predict(text)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print("accuracy: %d %%" % (100 * correct / total))

Prediction when network is forced to predict


  x = F.log_softmax(self.pred(x).squeeze())


accuracy: 50 %
