In [1]:
import torch
import torch.nn as nn
import pandas as pd
#import math
import matplotlib.pyplot as plt
#from tqdm import tqdm

import numpy as np
import torch.nn.functional as F
#from torchvision.utils import make_grid, save_image
import torch.distributions as dist
import os

from torch.utils.data import Dataset, DataLoader

import nltk
import re
from nltk.tokenize import word_tokenize
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
import torch.optim as optim

In [2]:
df = pd.read_csv("IMDB_Dataset.csv")
print(df.head(), "\n")

                                              review sentiment
0  One of the other reviewers has mentioned that ...  positive
1  A wonderful little production. <br /><br />The...  positive
2  I thought this was a wonderful way to spend ti...  positive
3  Basically there's a family where a little boy ...  negative
4  Petter Mattei's "Love in the Time of Money" is...  positive 



In [3]:
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from collections import Counter

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('punkt_tab')
  

#stop_words = set(stopwords.words('english'))
#stemmer = PorterStemmer()

df_tok = df[:10000].copy() 

df_tok

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Arthur\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Arthur\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Arthur\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Unnamed: 0,review,sentiment
0,One of the other reviewers has mentioned that ...,positive
1,A wonderful little production. <br /><br />The...,positive
2,I thought this was a wonderful way to spend ti...,positive
3,Basically there's a family where a little boy ...,negative
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive
...,...,...
9995,"Fun, entertaining movie about WWII German spy ...",positive
9996,Give me a break. How can anyone say that this ...,negative
9997,This movie is a bad movie. But after watching ...,negative
9998,This is a movie that was probably made to ente...,negative


In [4]:
# Tokenize, remove stopwords, and stem
df_tok['review'] = df_tok['review'].apply(lambda x: word_tokenize(x.lower()))
print("Tokenized")
#df_tok['review'] = df_tok['review'].apply(lambda x: [word for word in x if word not in stop_words])
#print("Removed stopwords")
#df_tok['review'] = df_tok['review'].apply(lambda x: [stemmer.stem(word) for word in x])
#print("Stemmed")

Tokenized


In [5]:
# Limit vocabulary size
vocab_size = 10000
all_words = [word for review in df_tok['review'] for word in review]
word_counts = Counter(all_words)
most_common_words = set([word for word, count in word_counts.most_common(vocab_size)])

df_tok['review'] = df_tok['review'].apply(lambda x: [word if word in most_common_words else '<UNK>' for word in x])

# Encode reviews
word2value = {word: idx for idx, word in enumerate(most_common_words, start=1)}
word2value['<UNK>'] = 0

df_enc = df_tok.copy()
df_enc['review'] = df_enc['review'].apply(lambda x: [word2value[word] for word in x])

# Convert to tensors and pad
review_tensors = [torch.tensor(encoded_review) for encoded_review in df_enc['review']]
padded = pad_sequence(review_tensors, batch_first=True, padding_value=0).narrow(1, 0, 100)

In [6]:
# Initialize lists to store values for the DataFrame
reviews = []
sentiment_values = []

# Iterate over each row of the tensor
for i in range(padded.size(0)):
    # Extract the row from the tensor
    row = padded[i]
    
    # Convert the tensor row to a list and append it to the 'reviews' list
    reviews.append(row.tolist())
    # Extract the corresponding value of 'sentiment' column from the other DataFrame
    sentiment_value = df_tok.iloc[i]['sentiment']
    # Append the value to the 'sentiment_values' list
    if sentiment_value == "positive" :
        sentiment_values.append(1)
    else :
        sentiment_values.append(0)

# Final dataframe
final_df = pd.DataFrame({'review': reviews, 'sentiment': sentiment_values})

print(final_df)


####### Train-test split ########

# Split the data into training and testing sets (80% training, 20% testing)
train_df, test_df = train_test_split(final_df, test_size=0.15, random_state=42)

print("Training set shape:", train_df.shape)
print("Testing set shape:", test_df.shape)


class CustomDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Extract features and labels for a single row
        features = torch.tensor(self.dataframe.iloc[idx]['review'], dtype=torch.float32) 
        label = torch.tensor(self.dataframe.iloc[idx]['sentiment'], dtype=torch.long)
        index = torch.tensor(self.dataframe.iloc[idx].name, dtype=torch.long)  
        return index, features, label

                                                 review  sentiment
0     [7759, 2175, 6046, 6132, 1739, 7074, 4869, 553...          1
1     [7288, 8050, 9881, 9776, 8880, 9549, 5673, 527...          1
2     [8994, 2373, 5143, 1683, 7288, 8050, 4542, 464...          1
3     [573, 4082, 8812, 7288, 3695, 551, 7288, 9881,...          0
4     [0, 4777, 8812, 7021, 9773, 3709, 6046, 9746, ...          1
...                                                 ...        ...
9995  [172, 5131, 6534, 698, 5250, 9507, 83, 2286, 1...          1
9996  [343, 2651, 7288, 2939, 8880, 2386, 382, 8527,...          0
9997  [5143, 698, 5816, 7288, 6915, 698, 8880, 84, 3...          0
9998  [5143, 5816, 7288, 698, 5538, 1683, 713, 6266,...          0
9999  [4518, 1376, 5250, 5694, 8880, 4385, 6046, 868...          1

[10000 rows x 2 columns]
Training set shape: (8500, 2)
Testing set shape: (1500, 2)


In [7]:
# Define batch size and number of epochs
batch_size = 200

# Create custom datasets
train_dataset = CustomDataset(train_df)
test_dataset = CustomDataset(test_df)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


data_train = iter(train_loader)

In [8]:
len(train_dataset)/batch_size

42.5

In [10]:
# Blocs du modèle

class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)
    
class Encoder(nn.Module):
    def __init__(self, input_size, z_dim, hidden_dim, bidirectional=True):
        super(Encoder, self).__init__()
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.hidden_factor = 2 if bidirectional else 1

        self.encoder = nn.LSTM(input_size, hidden_dim, bidirectional=True, batch_first=True)

        self.locs = nn.Linear(hidden_dim*self.hidden_factor, z_dim)
        self.scales = nn.Linear(hidden_dim*self.hidden_factor, z_dim)
    
    def forward(self, x):
        batch_size = x.size(0)
        # x should be of shape [batch_size, seq_len, input_size]
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # shape [1, 1, input_size]
        elif x.dim() == 2:
            x = x.unsqueeze(2)  # shape [batch_size, 1, input_size]
        
        output, (hidden, c_n) = self.encoder(x)
        #print("pre hidden", h_n.shape)
        #hidden = h_n[-1]
        hidden = hidden.view(batch_size, self.hidden_dim*self.hidden_factor)

        locs = self.locs(hidden)
        scales = torch.clamp(F.softplus(self.scales(hidden)), min=1e-3)
        return locs, scales

class Decoder(nn.Module):
    def __init__(self, input_size, z_dim, hidden_dim, vocab_size, bidirectional=True):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.bidirectional = bidirectional
        self.hidden_factor = 2 if bidirectional else 1

        # Linear layer to transform z_dim to hidden_dim
        self.linear = nn.Linear(z_dim, hidden_dim * self.hidden_factor)

        # LSTM to generate sequences
        self.lstm = nn.LSTM(input_size, hidden_dim, bidirectional=bidirectional, batch_first=True)
        self.output = nn.Linear(hidden_dim * self.hidden_factor, self.vocab_size+1)

    def forward(self, z, xs):
        batch_size = z.size(0)
        seq_len = xs.size(1)

        # Transform latent vector to initial hidden state
        hidden = self.linear(z).view(self.hidden_factor, batch_size, self.hidden_dim)

        if xs.is_cuda:
            hidden = hidden.cuda()

        c_0 = torch.zeros(self.hidden_factor, batch_size, self.hidden_dim)
        if xs.is_cuda:
            c_0 = c_0.cuda()

        decoder_hidden = (hidden, c_0)

        # Forward pass through LSTM
        outputs, _ = self.lstm(xs, decoder_hidden)

        outputs = self.output(outputs)
        # Apply log softmax to get log probabilities over the vocabulary
        logp = nn.functional.log_softmax(outputs, dim=-1)
        return logp

class Diagonal(nn.Module):
    def __init__(self, dim):
        super(Diagonal, self).__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(self.dim))
        self.bias = nn.Parameter(torch.zeros(self.dim))

    def forward(self, x):
        return x * self.weight + self.bias

class Classifier(nn.Module):
    def __init__(self, dim):
        super(Classifier, self).__init__()
        self.dim = dim
        self.diag = Diagonal(self.dim)

    def forward(self, x):
        return self.diag(x)

class CondPrior(nn.Module):
    def __init__(self, dim):
        super(CondPrior, self).__init__()
        self.dim = dim
        self.diag_loc_true = nn.Parameter(torch.zeros(self.dim))
        self.diag_loc_false = nn.Parameter(torch.zeros(self.dim))
        self.diag_scale_true = nn.Parameter(torch.ones(self.dim))
        self.diag_scale_false = nn.Parameter(torch.ones(self.dim))

    def forward(self, x):
        x = x.unsqueeze(1) 
        loc = x * self.diag_loc_true + (1 - x) * self.diag_loc_false
        scale = x * self.diag_scale_true + (1 - x) * self.diag_scale_false
        return loc, torch.clamp(F.softplus(scale), min=1e-3)

In [24]:
# CCVAE model

def compute_kl(locs_q, scale_q, locs_p=None, scale_p=None):
    """
    Computes the KL(q||p)
    """
    if locs_p is None:
        locs_p = torch.zeros_like(locs_q)
    if scale_p is None:
        scale_p = torch.ones_like(scale_q)

    dist_q = dist.Normal(locs_q, scale_q)
    dist_p = dist.Normal(locs_p, scale_p)
    return dist.kl.kl_divergence(dist_q, dist_p).sum(dim=-1)

def img_log_likelihood(recon, xs):
        if xs.dim() == 1:
            xs = xs.unsqueeze(0).unsqueeze(0)  # shape [1, 1, input_size]
            recon = recon.unsqueeze(0).unsqueeze(0)
        elif xs.dim() == 2:
            xs = xs.unsqueeze(1)  # shape [batch_size, 1, input_size]
            recon = recon.unsqueeze(1)
        laplace_dist = dist.Laplace(recon, torch.ones_like(recon))
        log_prob = laplace_dist.log_prob(xs)
        return log_prob.sum(dim=(0, 1, 2))

def sentence_log_likelihood(recon, xs):
    vocab_size = recon.shape[-1] 
    
    target = xs.squeeze(-1)  # Shape: (batch_size, seq_len)
    target = target.long()

    # Reshape output and target for CrossEntropyLoss
    recon = recon.view(-1, vocab_size)  # Shape: (batch_size * seq_len, vocab_size)
    target = target.view(-1)  # Shape: (batch_size * seq_len)

    # Define the loss function
    criterion = nn.CrossEntropyLoss()

    # Compute the loss
    loss = criterion(recon, target)
    return loss


class CCVAE(nn.Module):
    """
    CCVAE
    """
    def __init__(self, input_size, hidden_dim, vocab_size, z_dim, num_classes, use_cuda, prior_fn):
        super(CCVAE, self).__init__()
        self.z_dim = z_dim
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.z_classify = num_classes
        self.z_style = z_dim - num_classes
        self.use_cuda = use_cuda
        self.num_classes = num_classes
        self.ones = torch.ones(1, self.z_style)
        self.zeros = torch.zeros(1, self.z_style)
        self.y_prior_params = prior_fn

        self.classifier = Classifier(self.num_classes)

        self.encoder = Encoder(self.input_size, self.z_dim, self.hidden_dim)
        self.decoder = Decoder(self.input_size, self.z_dim, self.hidden_dim, self.vocab_size)

        self.cond_prior = CondPrior(self.num_classes)

        if self.use_cuda:
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()
            self.y_prior_params = self.y_prior_params.cuda()
            self.cuda()

    def sup(self, x, y):
        y = y.float()
        bs = x.shape[0]
        #inference
        post_params = self.encoder(x)
        
        z = dist.Normal(*post_params).rsample()
        zc, zs = z.split([self.z_classify, self.z_style], 1)
        qyzc = dist.Bernoulli(logits=self.classifier(zc))
        log_qyzc = qyzc.log_prob(y.view(-1, 1)).sum(dim=-1)

        # compute kl
        locs_p_zc, scales_p_zc = self.cond_prior(y)
        prior_params = (torch.cat([locs_p_zc, self.zeros.expand(bs, -1)], dim=1), 
                        torch.cat([scales_p_zc, self.ones.expand(bs, -1)], dim=1))
        kl = compute_kl(*post_params, *prior_params)

        #compute log probs for x and y

        log_py = dist.Bernoulli(self.y_prior_params.expand(bs, -1)).log_prob(y.view(-1, 1)).sum(dim=-1)

        recon = self.decoder(z, x)

        log_qyx = self.classifier_loss(x, y)
        #recon = recon.transpose(0, 1)
        #recon = recon.reshape(bs, -1, 1)  # Transpose back to [batch_size, seq_len, hidden_dim]
        log_pxz = sentence_log_likelihood(recon, x)


        # we only want gradients wrt to params of qyz, so stop them propogating to qzx
        log_qyzc_ = dist.Bernoulli(logits=self.classifier(zc.detach())).log_prob(y.view(-1, 1)).sum(dim=-1)
        #print("log_qyzc_", log_qyzc_, "\nlog_qyx", log_qyx)
        w = torch.exp(log_qyzc_ - log_qyx)+ 1e-8
        elbo = (w * (-log_pxz - kl - log_qyzc) + log_py + log_qyx).mean()
        return -elbo

    def classifier_loss(self, x, y, k=100):
        """
        Computes the classifier loss.
        """
        zc, _ = dist.Normal(*self.encoder(x)).rsample(torch.tensor([k])).split([self.z_classify, self.z_style], -1)
        logits = self.classifier(zc.view(-1, self.z_classify))
        d = dist.Bernoulli(logits=logits)
        y = y.unsqueeze(0).unsqueeze(-1)
        y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes)
        lqy_z = d.log_prob(y).view(k, x.shape[0], self.num_classes).sum(dim=-1)
        lqy_x = torch.logsumexp(lqy_z, dim=0) - np.log(k)
        return lqy_x

    #def reconstruct_img(self, x):
    #    return self.decoder(dist.Normal(*self.encoder(x)).rsample())

    def classifier_acc(self, x, y=None, k=1):
        zc, _ = dist.Normal(*self.encoder(x)).rsample(torch.tensor([k])).split([self.z_classify, self.z_style], -1)
        logits = self.classifier(zc.view(-1, self.z_classify)).view(-1, self.num_classes)
        y = y.unsqueeze(0).unsqueeze(-1)
        y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes)
        preds = torch.round(torch.sigmoid(logits))
        acc = (preds.eq(y)).float().mean()
        return acc
    
    def accuracy(self, data_loader, *args, **kwargs):
        acc = 0.0
        for (_, x, y) in data_loader:
            if self.use_cuda:
                x, y = x.cuda(), y.cuda()
            batch_acc = self.classifier_acc(x, y)
            acc += batch_acc
        return acc / len(data_loader)

In [9]:
input_size = 1 
seq_len = 100
hidden_dim = 256
z_dim = 28
num_classes = 1
use_cuda = False
prior_fn = torch.tensor([[0.5]])
n_epochs = 10

In [10]:
num_batch = int(len(train_dataset)/batch_size)
num_batch

42

In [11]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

In [None]:
import gc

# Training
cc_vae = CCVAE(input_size, hidden_dim, vocab_size, z_dim, num_classes, use_cuda, prior_fn)

cc_vae.apply(init_weights)

optim = torch.optim.Adam(params=cc_vae.parameters(), lr=0.001)

for epoch in range(n_epochs):
    print("Epoch:", epoch)
    cc_vae.train()
    
    epoch_losses_sup = 0.0
    for batch_idx, (_, xs, ys) in enumerate(train_loader):
        print("  Batch:", batch_idx)
        xs = xs.view(xs.shape[0], seq_len, input_size)
        loss = cc_vae.sup(xs, ys)
        
        optim.zero_grad() 
        loss.backward()
        optim.step()

        epoch_losses_sup += loss.detach().item()
        print(f"Batch {batch_idx} Loss: {loss.detach().item():.3f}")
        gc.collect()
    
    avg_loss = epoch_losses_sup / len(train_loader)
    print(f"[Epoch {epoch+1:03d}] Sup Loss: {avg_loss:.3f}")

#cc_vae.eval() 
#test_acc = cc_vae.accuracy(test_loader)

#print(test_acc)

Epoch: 0
  Batch: 0
Batch 0 Loss: 20.146
  Batch: 1
Batch 1 Loss: 17.667
  Batch: 2
Batch 2 Loss: 15.381
  Batch: 3
Batch 3 Loss: 13.668
  Batch: 4


In [None]:
#Accuracy
cc_vae.eval() 

test_acc = cc_vae.accuracy(test_loader)

print(test_acc)


tensor(0.5019)


In [29]:
#Intervention

def intervention(model, x, y):
    post_params = model.encoder(x)
    z = dist.Normal(*post_params).rsample()
    zc, zs = z.split([model.z_classify, model.z_style], 1)
    locs_p_zc, scales_p_zc = model.cond_prior(1 - y)
    zc_new = dist.Normal(locs_p_zc, scales_p_zc).rsample()
    z_new = torch.cat((zc_new,zs),dim=1)
    recons = model.decoder(z_new,x)
    return recons



In [30]:
value2word = {v: k for k, v in word2value.items()}

In [38]:
def decode_recons(recons):
    recons = recons.round()
    r = torch.zeros(recons.shape[1])
    for i in range(recons.shape[1]):
        r[i] = torch.argmax(recons[0][i]) 
    decoded_review = [value2word.get(idx.item(), '<UNK>') for idx in r]
    decoded_text = '.'.join(decoded_review)
    return decoded_text

In [39]:
k = 0
C = 3
for (i,a ,b) in test_loader :
    while k < C:
        y = b[k].view(1)
        x = a[k].view(1,seq_len,input_size)
        recons = intervention(cc_vae, x, y)
        dec = decode_recons(recons)
        print("True review", df_tok.iloc[i[k].item()]['review'])
        print("Modified review", dec)
        k += 1

True review ['this', 'is', 'a', 'pretty', 'pointless', 'remake', '.', 'starting', 'with', 'the', 'opening', 'title', 'shots', 'of', 'the', 'original', 'was', 'a', 'real', 'mistake', 'as', 'it', 'reminds', 'the', 'viewer', 'of', 'what', 'a', 'great', 'little', 'period', 'piece', '<UNK>', 'that', 'was', '.', 'the', 'new', 'version', 'that', 'follows', 'is', 'an', 'exercise', 'in', '<UNK>', '<', 'br', '/', '>', '<', 'br', '/', '>', 'brian', '<UNK>', 'plays', 'a', '<UNK>', 'boy', "'", 'photographer', 'who', 'returns', 'to', 'a', '<UNK>', 'desert', 'town', 'populated', 'by', 'a', '<UNK>', 'of', '<UNK>', 'clichéd', 'stock', 'characters', ':', 'the', '<UNK>', 'sucking', '<UNK>', '<UNK>', ',', 'the', '<UNK>', 'old', '<UNK>', '<UNK>', ',', 'the', 'crippled', 'vet', 'and', 'his', 'asian', 'wife', ',', 'etc', '...', '<', 'br', '/', '>', '<', 'br', '/', '>', '<UNK>', "'s", 'character', 'witnesses', 'the', 'crashing', 'of', '<UNK>', "'", 'into', 'a', '<UNK>', 'and', 'shortly', 'after', 'strange', '

M2

In [12]:
class Encoder(nn.Module):
    def __init__(self, input_size, z_dim, hidden_dim, num_classes, bidirectional=True):
        super(Encoder, self).__init__()
        self.z_dim = z_dim
        self.input_size = input_size
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.hidden_factor = 2 if bidirectional else 1
        
        self.encoder = nn.LSTM(input_size, hidden_dim, bidirectional=True, batch_first=True)

        self.locs = nn.Linear(hidden_dim*self.hidden_factor, z_dim)
        self.scales = nn.Linear(hidden_dim*self.hidden_factor, z_dim)

    def forward(self, x, y):
        # x should be of shape [batch_size, seq_len, input_size]
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # shape [1, 1, input_size]
        elif x.dim() == 2:
            x = x.unsqueeze(2)  # shape [batch_size, 1, input_size]

        y = y.view(batch_size,self.num_classes,self.input_size)
        x_cat_y = torch.cat((x,y), 1)

        output, (hidden, c_n) = self.encoder(x_cat_y)
       
        hidden = hidden.view(batch_size, self.hidden_dim*self.hidden_factor)

        locs = self.locs(hidden)
        scales = torch.clamp(F.softplus(self.scales(hidden)), min=1e-3)
        return locs, scales
    

class Decoder(nn.Module):
    def __init__(self, input_size, z_dim, num_classes, hidden_dim, vocab_size, seq_len,bidirectional=True):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.bidirectional = bidirectional
        self.hidden_factor = 2 if bidirectional else 1

        # Linear layer to transform z_dim to hidden_dim
        self.linear = nn.Linear(z_dim + num_classes, hidden_dim * self.hidden_factor)

        # LSTM to generate sequences
        self.lstm = nn.LSTM(input_size, hidden_dim, bidirectional=True, batch_first=True)

        self.output = nn.Linear(hidden_dim * self.hidden_factor, self.vocab_size + 1)
        
    def forward(self, z, ysoft, xs):
        batch_size = z.size(0)
        z_cat_ysoft = torch.cat((z,ysoft), dim=1)
        hidden = self.linear(z_cat_ysoft).view(self.hidden_factor, batch_size, self.hidden_dim)

        #if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            #hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_dim)
        
        if xs.is_cuda:
            hidden = hidden.cuda()

        c_0 = torch.zeros(self.hidden_factor, batch_size, self.hidden_dim)  # Shape: [hidden_factor, batch_size, hidden_dim]
        if xs.is_cuda:
            c_0 = c_0.cuda()
        
        decoder_hidden = (hidden, c_0)
        outputs, _ = self.lstm(xs, decoder_hidden)

        outputs = self.output(outputs)

        logp = nn.functional.log_softmax(outputs, dim=-1)

        return logp
    
class Classifier(nn.Module):
    def __init__(self, input_size, hidden_dim, num_classes):
        super(Classifier, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim,num_classes)
        self.softmax = nn.Softmax(dim = 1)
    def forward(self, x):
        # x should be of shape [batch_size, seq_len, input_size]
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # shape [1, 1, input_size]
        elif x.dim() == 2:
            x = x.unsqueeze(2)  # shape [batch_size, 1, input_size]
            
        _, (h,c) = self.lstm(x)
        logits = self.linear(h[-1])
        #preds = self.softmax(logits)
        preds = torch.sigmoid(logits)
        return logits, preds

In [13]:
def sentence_log_likelihood(recon, xs):
    vocab_size = recon.shape[-1] 
    
    target = xs.squeeze(-1)  # Shape: (batch_size, seq_len)
    target = target.long()

    # Reshape output and target for CrossEntropyLoss
    recon = recon.view(-1, vocab_size)  # Shape: (batch_size * seq_len, vocab_size)
    target = target.view(-1)  # Shape: (batch_size * seq_len)

    # Define the loss function
    criterion = nn.CrossEntropyLoss()

    # Compute the loss
    loss = criterion(recon, target)
    return loss

class m2_model(nn.Module):
    def __init__(self,input_size, z_dim, num_classes, hidden_dim, seq_len, vocab_size, batch_size,use_cuda):
        super(m2_model, self).__init__()
        self.use_cuda = use_cuda
        self.input_size = input_size
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.encoder = Encoder(self.input_size, self.z_dim, self.hidden_dim, num_classes)
        self.decoder = Decoder(self.input_size, self.z_dim, self.num_classes, self.hidden_dim, self.vocab_size, self.seq_len)
        self.classifier = Classifier(self.input_size, self.hidden_dim, self.num_classes)

        self.cat_loss = nn.BCEWithLogitsLoss(reduce=False)
    
    def sup(self, x, y):
        logits, preds = self.classifier(x)
        locs_z, scales_z = self.encoder(x,y)
        z = dist.Normal(locs_z, scales_z).rsample()
        recons = self.decoder(z,preds,x)
        
        llk = sentence_log_likelihood(recons,x)
        
        cat = self.cat_loss(logits.squeeze(1),y.float())
        kld_norm = torch.sum(0.5 * ( locs_z**2 + scales_z - 1 - torch.log(scales_z)),-1)
        
        loss = (llk + cat + kld_norm).mean()

        return loss, llk, cat, kld_norm
    
    def classifier_acc(self, x, y=None, k=1):
        _, preds = self.classifier(x)
        preds = torch.round(preds)
        acc = (preds.eq(y)).float().mean()
        return acc
    
    def accuracy(self, data_loader, *args, **kwargs):
        acc = 0.0
        for (_, x, y) in data_loader:
            if self.use_cuda:
                x, y = x.cuda(), y.cuda()
            batch_acc = self.classifier_acc(x, y)
            acc += batch_acc
        return acc / len(data_loader)

In [19]:
import gc

# Training
m2 = m2_model(input_size, z_dim, num_classes, hidden_dim, seq_len, vocab_size, batch_size, use_cuda)

m2.apply(init_weights)

optim = torch.optim.Adam(params=m2.parameters(), lr=0.001)

for epoch in range(n_epochs):
    print("Epoch:", epoch)
    m2.train()
    
    epoch_losses_sup = 0.0
    #for i in tqdm(range(num_batch)):
    for batch_idx, (_, xs, ys) in enumerate(train_loader):
        print("  Batch:", batch_idx)
        #xs, ys = next(train_loader)
        xs = xs.view(batch_size, seq_len, input_size)
        loss, llk, cat, kld_norm = m2.sup(xs, ys)

        #print("llk", llk.shape, llk.mean())
        #print("cat", cat.shape, cat.mean())
        #print("kld_norm", kld_norm.shape, kld_norm.mean())


        optim.zero_grad() 
        loss.backward()
        optim.step()

        epoch_losses_sup += loss.detach().item()
        print(f"Batch {batch_idx} Loss: {loss.detach().item():.3f}")

        # Debugging: Print gradient norms
        total_norm = 0
        for p in m2.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        print(f"Gradient norm: {total_norm:.3f}")

        gc.collect()
    
    avg_loss = epoch_losses_sup / len(train_loader)
    print(f"[Epoch {epoch+1:03d}] Sup Loss: {avg_loss:.3f}")



Epoch: 0
  Batch: 0
Batch 0 Loss: 15.509
Gradient norm: 69.520
  Batch: 1
Batch 1 Loss: 13.226
Gradient norm: 141.523
  Batch: 2


KeyboardInterrupt: 

In [15]:
#Accuracy
m2.eval() 

test_acc = m2.accuracy(test_loader)

print(test_acc)

tensor(0.4939)
