In [None]:
import torch
import torch.distributions as D
import numpy as np
import math
from torchinfo import summary
import random

# Use the GPU when one is present, otherwise fall back to the CPU so that the
# notebook still runs end to end on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every random number generator the notebook draws from so that corpus
# generation and weight initialisation are repeatable.
seed = 1000
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed(seed)

### Dataset setup exchangeable

In [None]:
# Corpus dimensions.
V, K = 1000, 5       # vocabulary size, number of topics
N, M = 12000, 100    # number of documents, words per document
alpha = 0.5          # concentration of the per-document topic prior

# Symmetric Dirichlet priors: the smaller the concentration, the less evenly
# the mass is spread over the simplex.
dcht_V = D.dirichlet.Dirichlet(torch.full([V], 0.5))     # prior over words
dcht_K = D.dirichlet.Dirichlet(torch.full([K], alpha))   # prior over topics

# One word distribution per topic, i.e. a (K, V) matrix.
topic_vecs = dcht_V.sample([K])
print(topic_vecs.shape)

In [None]:
# Token ids of every document and the topic proportions they were drawn from.
dataset = torch.zeros([N, M], dtype=torch.long)
topic_mixtures = torch.zeros([N, K], dtype=torch.float32)

for n in range(N):

    # Topic proportions of document n.
    theta = dcht_K.sample([1]).squeeze()
    topic_mixtures[n, :] = theta

    # One topic assignment per word position of document n.
    topic_assignments = D.categorical.Categorical(theta).sample([M])

    for m in range(M):
        k = topic_assignments[m]
        # Draw the word from the word distribution of the assigned topic.
        word = D.categorical.Categorical(topic_vecs[k]).sample([1]).squeeze()
        dataset[n, m] = word

    # Progress indicator every 1000 documents.
    if n % 1000 == 0:
        print(n)

In [None]:
# Persist the generated token matrix, the per-document topic mixtures and the
# topic-word distributions as text files.
# NOTE(review): the file names say "N10000" while the corpus is configured with
# N = 12000 documents — confirm the naming is intentional.
np.savetxt('../data/dataset_005_N10000_V1000_ID1.csv', dataset.cpu().numpy())
np.savetxt('../data/topic_mixtures_005_N10000_V1000_ID1.csv', topic_mixtures.cpu().numpy())
np.savetxt('../data/topic_vecs_005_N10000_V1000_ID1.csv', topic_vecs.cpu().numpy())

In [None]:
# Reload the three arrays written by the previous cell. The token matrix is cast
# to int64; the other two keep whatever floating dtype np.genfromtxt returns.
dataset = torch.tensor(np.genfromtxt('../data/dataset_005_N10000_V1000_ID1.csv'), dtype=torch.int64)
topic_mixtures = torch.tensor(np.genfromtxt('../data/topic_mixtures_005_N10000_V1000_ID1.csv'))
topic_vecs = torch.tensor(np.genfromtxt('../data/topic_vecs_005_N10000_V1000_ID1.csv'))

In [None]:
# Split boundaries used by the slicing in later cells: rows [0, val_idx) form the
# training split, [val_idx, test_idx) the validation split and [test_idx, end)
# the test split.
val_idx = 10000
test_idx = 11000
# Displays the number of documents as the cell output.
len(dataset)

### LLM

In [None]:
# Special-token ids. Within this notebook only START_IDX is referenced (by
# plot_example); END_IDX and PAD_IDX are not used by the visible code.
# NOTE(review): with V = 1000 these ids lie inside the ordinary word-id range —
# confirm they are still intended.
START_IDX = 100
END_IDX = 101
PAD_IDX = 102

def generate_square_subsequent_mask(sz, device='cpu'):
    """Build an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf
    otherwise, so that adding the mask to attention scores hides future
    positions.
    """
    visible = torch.tril(torch.ones((sz, sz), device=device)).bool()
    additive = torch.zeros((sz, sz), device=device)
    return additive.masked_fill(~visible, float('-inf'))

def create_mask(src, tgt, device='cpu'):
    """Return (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask) on `device`.

    Both attention masks are causal and sized from dimension 1 of the
    corresponding input. The padding masks mark entries equal to the sentinel
    value 999999.
    """
    target_device = torch.device(device)

    # Causal masks are built first and moved to the requested device afterwards.
    tgt_mask = generate_square_subsequent_mask(tgt.shape[1])
    src_mask = generate_square_subsequent_mask(src.shape[1])

    src_padding_mask = src == 999999
    tgt_padding_mask = tgt == 999999

    masks = (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)
    return tuple(m.to(target_device) for m in masks)

# src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(torch.zeros((3,10)), torch.zeros((3,11)), device=device)

In [None]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    Args:
        d_model: size of the last (feature) dimension of the input. Both even
            and odd values are supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions for which the table is precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so the angle table is truncated to the number of odd feature indices.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as a buffer: part of the module state but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return dropout(x + positional table), where x is indexed (batch, position, feature)."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TransformerModel(torch.nn.Module):
    """Token-level model: embedding -> optional positions -> encoder stack -> vocabulary logits.

    After every forward pass the encoder output (before the final linear layer)
    is kept on the instance as `doc_embd`.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5, use_pos = False):
        super().__init__()
        self.model_type = 'Transformer'

        # Sub-modules are created in the same order as their attributes are
        # listed here; batch_first=True means inputs are (batch, position, feature).
        layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_hid,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, num_layers=nlayers)
        self.embedding = torch.nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = torch.nn.Linear(d_model, ntoken)
        self.use_pos = use_pos
        if use_pos:
            self.pos_encoder = PositionalEncoding(d_model, dropout)

    def forward(self, src, memory=None, src_mask=None):
        """Return logits of shape (batch, position, ntoken).

        `memory` is accepted for call compatibility but is never read.
        `src_mask` is handed to the encoder as its attention mask.
        """
        hidden = self.embedding(src) * np.sqrt(self.d_model)
        if self.use_pos:
            hidden = self.pos_encoder(hidden)
        hidden = self.transformer_encoder(hidden, mask=src_mask)
        # Kept so that callers can read the contextual embeddings afterwards.
        self.doc_embd = hidden
        return self.linear(hidden)

In [None]:
# Move the corpus and its ground-truth mixtures to the compute device.
dataset, topic_mixtures = (t.to(torch.device(device)) for t in (dataset, topic_mixtures))

# Contiguous train / validation / test splits at the val_idx and test_idx boundaries.
train_dataset, val_dataset, test_dataset = (
    dataset[:val_idx], dataset[val_idx:test_idx], dataset[test_idx:]
)
train_mixtures, val_mixtures, test_mixtures = (
    topic_mixtures[:val_idx], topic_mixtures[val_idx:test_idx], topic_mixtures[test_idx:]
)

def get_loss_tv(output, target):
    """Return, for every row, the largest absolute difference between output and target.

    Note that this is the maximum coordinate-wise gap (an L-infinity distance),
    one value per row; no reduction over rows is applied.
    """
    largest_gap, _ = torch.max((output - target).abs(), dim=1)
    return largest_gap

def get_loss_l2(output, target):
    """Return the squared Euclidean distance per row, averaged over all rows."""
    per_row = torch.square(output - target).sum(dim=1)
    return per_row.mean()

In [None]:
# Width of the token embeddings and of the encoder layers.
d_model = 128
# Positional arguments, in order: ntoken=V, d_model, nhead=8, d_hid=d_model,
# nlayers=4, dropout=0.1, use_pos=True.
model = TransformerModel(V, d_model, 8, d_model, 4, 0.1, True).to(torch.device(device))
# The three input sizes correspond to forward(src, memory, src_mask): a batch of
# one document with M-1 tokens, a dummy memory tensor and an (M-1, M-1) mask.
summary(model, [(1,M-1), (1,1,d_model), (M-1,M-1)], dtypes=[torch.int32,torch.float32,torch.float32])

In [None]:
'Transformer training'
# Additive causal mask over the M-1 input positions: position i may only attend
# to positions <= i.
sz = M-1
mask = generate_square_subsequent_mask(sz, device=device)
# mask = None

d_model = 128
model = TransformerModel(V, d_model, 8, d_model, 4, 0.1, True).to(torch.device(device))

criterion = torch.nn.CrossEntropyLoss()
lr = 0.0001  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
batch_size = 16
# Start from +inf so the first epoch always produces a checkpoint, whatever the
# scale of the summed validation loss.
min_val_loss = float('inf')

for epoch in range(200):

    # ---- training pass: predict token t+1 from tokens <= t ----
    model.train()
    total_loss = 0.

    for i in range(0, len(train_dataset), batch_size):
        end_idx = min(i+batch_size, len(train_dataset))
        data = train_dataset[i:end_idx, :-1]
        target = train_dataset[i:end_idx, 1:]

        # TransformerModel.forward never reads its `memory` argument, so only
        # the tokens and the causal mask are passed.
        output = model(data, src_mask=mask)

        # Flatten (batch, position) so every position is one classification example.
        output_flat = torch.reshape(output, (-1, V))
        target_flat = torch.reshape(target, (-1,))
        loss = criterion(output_flat, target_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight it by the batch size so that the
        # epoch figure is a per-document average even when the last batch is short.
        total_loss += loss.item() * data.shape[0]

    print('Total train loss', total_loss / len(train_dataset))

    # ---- validation pass: no parameter updates, so no autograd graph is needed ----
    model.eval()
    total_loss = 0.

    with torch.no_grad():
        for i in range(0, len(val_dataset), batch_size):
            end_idx = min(i+batch_size, len(val_dataset))
            data = val_dataset[i:end_idx, :-1]
            target = val_dataset[i:end_idx, 1:]

            output = model(data, src_mask=mask)

            output_flat = torch.reshape(output, (-1, V))
            target_flat = torch.reshape(target, (-1,))
            loss = criterion(output_flat, target_flat)

            total_loss += loss.item() * data.shape[0]

    print('Total val loss', total_loss / len(val_dataset))

    # Keep the weights of the epoch with the lowest validation loss.
    if total_loss < min_val_loss:
        min_val_loss = total_loss
        torch.save(model.state_dict(), '../results/transformer_model_weights_1.pth')



In [None]:
# Restore the best checkpoint; map_location lets a checkpoint written on one
# device be loaded on whichever device is in use now.
model.load_state_dict(torch.load('../results/transformer_model_weights_1.pth', map_location=torch.device(device)))
model.eval()

# First validation document without its last token (the model predicts the next token).
sample_input = val_dataset[[0], :-1].to(torch.device(device))

with torch.no_grad():
    # TransformerModel.forward never reads `memory`, so it is not passed.
    output = model(sample_input, src_mask=mask)

print(sample_input)
#print(output)
# Most likely next token at every position.
print(torch.argmax(output, dim=2))

In [None]:
# lr = 0.0001 for real and 0.001 for fake

# Which saved corpus to probe on. Named dataset_id rather than `id` so the
# builtin id() is not shadowed.
dataset_id = 1
# The files are read from '../data/', the directory the generation cell writes
# to and every other loading cell reads from.
dataset = torch.tensor(np.genfromtxt(f'../data/dataset_005_N10000_V1000_ID{dataset_id}.csv'), dtype=torch.int64)
topic_mixtures = torch.tensor(np.genfromtxt(f'../data/topic_mixtures_005_N10000_V1000_ID{dataset_id}.csv'))
topic_vecs = torch.tensor(np.genfromtxt(f'../data/topic_vecs_005_N10000_V1000_ID{dataset_id}.csv'))

dataset = dataset.to(torch.device(device))
topic_mixtures = topic_mixtures.to(torch.device(device))

# Contiguous train / validation / test splits at the val_idx and test_idx boundaries.
train_dataset = dataset[:val_idx]
val_dataset = dataset[val_idx:test_idx]
test_dataset = dataset[test_idx:]
train_mixtures = topic_mixtures[:val_idx]
val_mixtures = topic_mixtures[val_idx:test_idx]
test_mixtures = topic_mixtures[test_idx:]

def loss_mle(target_mixtures, q_params):
    """Mean negative log-density of `target_mixtures` under Dirichlet(q_params).

    `q_params` supplies the Dirichlet concentration parameters; the result is
    averaged over all evaluated points.
    """
    log_density = D.dirichlet.Dirichlet(q_params).log_prob(target_mixtures)
    return (-log_density).mean()


#train_mode = 'bayesian' # choose between 'classification', 'bayesian'
train_mode = 'classification'

# The probe is a single linear layer from the d_model-wide embedding to 5
# outputs. In 'bayesian' mode a Softplus follows it and the outputs are used as
# Dirichlet parameters (see loss_mle); in 'classification' mode they are logits.
if train_mode == 'bayesian':
    classifier = torch.nn.Sequential(
        torch.nn.Linear(d_model, 5),
        torch.nn.Softplus(),
    ).to(torch.device(device))

elif train_mode == 'classification':
    classifier = torch.nn.Sequential(
        torch.nn.Linear(d_model, 5)
    ).to(torch.device(device))


# Freeze the language model: only the probe's parameters are optimised below.
model.eval()
for param in model.parameters():
    param.requires_grad = False

# Summed (not averaged) cross-entropy, so epoch totals can be divided by the
# number of documents afterwards.
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
lr = 0.001  # learning rate
optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
batch_size = 64

# Which position's embedding is probed: an integer index into the sequence
# (-1 = last position); the special value 0.5 selects the mean over positions.
token_num = -1

# Probe training: the classifier is fitted on the validation split and scored on
# the test split after every epoch. The 'Val loss ...' lines printed at the end
# of each epoch are therefore test-split metrics.
for epoch in range(300):

    classifier.train()
    total_loss = 0.
    count = 0

    num_batches = len(val_dataset) // batch_size
    for i in range(0, len(val_dataset), batch_size):
        end_idx = min(i+batch_size, len(val_dataset))
        data = val_dataset[i:end_idx, :-1]
        # Re-run the frozen model's embedding / positional / encoder stages by
        # hand to obtain the contextual embeddings.
        src = model.embedding(data) * np.sqrt(model.d_model)
        if model.use_pos:
            src = model.pos_encoder(src)
        if token_num == 0.5:
            # Mean over all positions.
            outs = model.transformer_encoder(src, mask)
            embd = torch.mean(outs[:,:,:], dim=1)
        else:
            # Embedding of the single position selected by token_num.
            embd = model.transformer_encoder(src, mask)[:,token_num,:]


        target = val_mixtures[i:end_idx, :]

        output = classifier(embd)

        if train_mode == 'classification':
            loss = criterion(output, target)
        elif train_mode == 'bayesian':
            loss = loss_mle(target, output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # if count % 10 == 0:
        #     print(loss)
        #     # for param in classifier.parameters():
        #     #     print(torch.mean(param))

        count += 1

    print('Total train loss', total_loss/len(val_dataset))

    # ---- evaluation on the test split ----
    classifier.eval()
    total_loss = 0.
    loss_l2 = 0.
    loss_tv = 0.
    loss_ce = 0.
    count = 0
    pred_results = []

    num_batches = len(test_dataset) // batch_size
    for i in range(0, len(test_dataset), batch_size):
        end_idx = min(i+batch_size, len(test_dataset))
        data = test_dataset[i:end_idx, :-1]
        src = model.embedding(data) * np.sqrt(model.d_model)
        if model.use_pos:
            src = model.pos_encoder(src)
        if token_num == 0.5:
            outs = model.transformer_encoder(src, mask)
            embd = torch.mean(outs[:,:,:], dim=1)
        else:
            embd = model.transformer_encoder(src, mask)[:,token_num,:]

        target = test_mixtures[i:end_idx, :]

        output = classifier(embd)
        #output = torch.zeros_like(output)+0.2

        if train_mode == 'classification':
            loss = criterion(output, target)
        elif train_mode == 'bayesian':
            loss = loss_mle(target, output)
        total_loss += loss.item()

        # Accuracy compares the dominant topic of the prediction with the
        # dominant topic of the ground-truth mixture.
        true_class = torch.argmax(target, 1)
        pred_class = torch.argmax(output, 1)
        pred_result = true_class == pred_class
        pred_result = list(pred_result.cpu().numpy())
        pred_results.extend(pred_result)
        # if count % 10 == 0:
        #     print(loss)
        count += 1

        if train_mode == 'classification':
            # total_loss already is the running summed cross-entropy.
            loss_ce = total_loss
            loss_l2 += torch.sum(torch.sum((output.softmax(1)-target)**2, dim=1)).item()
            loss_tv += torch.sum(get_loss_tv(output.softmax(1), target)).item()
        elif train_mode == 'bayesian':
            # Point estimate: the average of 10 samples from the predicted Dirichlet.
            q = D.dirichlet.Dirichlet(np.squeeze(output))
            samp = torch.mean(torch.squeeze(q.sample([10])), dim=0)
            loss_ce += -torch.sum(torch.sum(torch.multiply(torch.log(samp), target), dim=1)).item()
            loss_l2 += torch.sum(torch.sum((samp-target)**2, dim=1)).item()
            loss_tv += torch.sum(get_loss_tv(samp, target)).item()

    pred_results = np.array(pred_results)
    acc = np.mean(pred_results)

    # All sums are divided by the number of test documents.
    print('Val loss CE', loss_ce/len(test_dataset))
    print('Val loss L2', loss_l2/len(test_dataset))
    print('Val loss TV', loss_tv/len(test_dataset))
    print('Accuracy', acc)

    # total variation distance

### BERT

In [None]:
import random
# Choose 15 distinct positions among 0..98; these are the positions hidden from
# the BERT model in the cells below.
mask_idx = random.sample(range(99), 15)
print(mask_idx)

In [None]:
# Contiguous train / validation / test splits at the val_idx and test_idx boundaries.
train_dataset, val_dataset, test_dataset = (
    dataset[:val_idx], dataset[val_idx:test_idx], dataset[test_idx:]
)
train_mixtures, val_mixtures, test_mixtures = (
    topic_mixtures[:val_idx], topic_mixtures[val_idx:test_idx], topic_mixtures[test_idx:]
)

In [None]:
def get_loss_tv(output, target):
    """Per-row largest absolute difference between output and target (no reduction over rows)."""
    largest_gap, _ = torch.max((output - target).abs(), dim=1)
    return largest_gap

def get_loss_l2(output, target):
    """Squared Euclidean distance per row, averaged over all rows."""
    per_row = torch.square(output - target).sum(dim=1)
    return per_row.mean()

In [None]:
'BERT'

import transformers

# Start from the "prajjwal1/bert-tiny" configuration and override the
# vocabulary: the V word ids are extended by three extra ids, V (bos),
# V+1 (eos) and V+2 (mask).
config = transformers.AutoConfig.from_pretrained(
    "prajjwal1/bert-tiny",
    vocab_size=V+3,
    n_ctx=100,
    bos_token_id=V,
    eos_token_id=V+1,
    mask_token_id=V+2
)
# The model is built from the configuration object only.
# NOTE(review): presumably this means freshly initialised weights rather than
# the pretrained checkpoint — confirm.
model = transformers.BertLMHeadModel(config).to(torch.device(device))
# Total number of parameters.
model_size = sum(t.numel() for t in model.parameters())
print(model_size)

criterion = torch.nn.CrossEntropyLoss()
lr = 0.0001  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
batch_size = 16
# Sentinel for "best validation loss so far"; compared against the summed
# per-document validation loss in the training loop.
min_val_loss = 1e6

# Id written into the hidden positions; the same value the BERT configuration
# registers as mask_token_id.
mask_token_id = V + 2

for epoch in range(200):

    # ---- training pass: recover the tokens at the masked positions ----
    model.train()
    total_loss = 0.

    for i in range(0, len(train_dataset), batch_size):
        end_idx = min(i+batch_size, len(train_dataset))
        target = train_dataset[i:end_idx, :].to(torch.device(device))
        # Copy of the batch with the chosen positions replaced by the mask id.
        masked_input = target.clone()
        masked_input[:, mask_idx] = mask_token_id

        logits = model(masked_input).logits

        # Only the masked positions contribute to the loss.
        output_flat = logits[:, mask_idx, :].reshape(-1, V+3)
        target_flat = target[:, mask_idx].reshape(-1)
        loss = criterion(output_flat, target_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight by batch size for a per-document average.
        total_loss += loss.item() * target.shape[0]

    print('Total train loss', total_loss/len(train_dataset))

    # ---- validation pass: no parameter updates, so no autograd graph is needed ----
    model.eval()
    total_loss = 0.

    with torch.no_grad():
        for i in range(0, len(val_dataset), batch_size):
            end_idx = min(i+batch_size, len(val_dataset))
            target = val_dataset[i:end_idx, :].to(torch.device(device))
            masked_input = target.clone()
            masked_input[:, mask_idx] = mask_token_id

            logits = model(masked_input).logits

            output_flat = logits[:, mask_idx, :].reshape(-1, V+3)
            target_flat = target[:, mask_idx].reshape(-1)
            loss = criterion(output_flat, target_flat)

            total_loss += loss.item() * target.shape[0]

    print('Total val loss', total_loss/len(val_dataset))

    # Keep the weights of the epoch with the lowest validation loss.
    if total_loss < min_val_loss:
        min_val_loss = total_loss
        torch.save(model.state_dict(), '../results/bert_model_weights.pth')

In [None]:
# Restore the best BERT checkpoint and inspect its predictions on the first
# validation document.
model.load_state_dict(torch.load('../results/bert_model_weights.pth'))

target = val_dataset[0:1, :]
# Copy of the document with the positions in mask_idx replaced by the mask id V+2.
input = torch.zeros_like(target)
input[:,:] = target
input[:,mask_idx] = V+2

input = input.to(torch.device(device))
target = target.to(torch.device(device))

output = model(input)

# Full masked input and full ground truth, then the masked positions only,
# then the model's most likely token everywhere and at the masked positions.
print(input)
print(target)
print(input[:,mask_idx])
print(target[:,mask_idx])
print(torch.argmax(output.logits, dim=2))
print(torch.argmax(output.logits, dim=2)[:,mask_idx])

In [None]:
'''
BERT probing
'''

# NOTE(review): this cell loads the corpus with ID 2, whereas the earlier
# loading cells use ID 1 — confirm that probing on a different corpus is intended.
id = 2
dataset = torch.tensor(np.genfromtxt(f'../data/dataset_005_N10000_V1000_ID{id}.csv'), dtype=torch.int64)
topic_mixtures = torch.tensor(np.genfromtxt(f'../data/topic_mixtures_005_N10000_V1000_ID{id}.csv'))
topic_vecs = torch.tensor(np.genfromtxt(f'../data/topic_vecs_005_N10000_V1000_ID{id}.csv'))

dataset = dataset.to(torch.device(device))
topic_mixtures = topic_mixtures.to(torch.device(device))

# Contiguous train / validation / test splits at the val_idx and test_idx boundaries.
train_dataset = dataset[:val_idx]
val_dataset = dataset[val_idx:test_idx]
test_dataset = dataset[test_idx:]
train_mixtures = topic_mixtures[:val_idx]
val_mixtures = topic_mixtures[val_idx:test_idx]
test_mixtures = topic_mixtures[test_idx:]

#train_mode = 'bayesian' # choose between 'classification', 'bayesian'
train_mode = 'classification'
# Input width of the probe.
# NOTE(review): assumes the BERT model's hidden size is 128 — confirm against the config.
d_model = 128

# Single linear probe; in 'bayesian' mode a Softplus follows it and the outputs
# are used as Dirichlet parameters, in 'classification' mode they are logits.
if train_mode == 'bayesian':
    classifier = torch.nn.Sequential(
        torch.nn.Linear(d_model, 5),
        torch.nn.Softplus(),
    ).to(torch.device(device))

elif train_mode == 'classification':
    classifier = torch.nn.Sequential(
        torch.nn.Linear(d_model, 5)
    ).to(torch.device(device))

# Freeze the language model: only the probe's parameters are optimised below.
model.eval()
for param in model.parameters():
    param.requires_grad = False

# Summed (not averaged) cross-entropy, so epoch totals can be divided by the
# number of documents afterwards.
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
lr = 0.003  # learning rate
optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
batch_size = 64

# 0.5 selects the mean over all positions; any other value is used as a
# position index into the last hidden layer.
token_num = 0.5

# Id written into the hidden positions. It has to be the id the BERT model was
# trained with (V+2, the configured mask token); the literal 102 used before is
# an ordinary word id and would feed the model a real word instead of a mask.
MASK_TOKEN_ID = V + 2


def _bert_doc_embedding(batch):
    """Return one embedding per document from the frozen BERT's last hidden layer.

    The positions in mask_idx are replaced by the mask token first. With
    token_num == 0.5 the hidden states are averaged over all positions,
    otherwise the position indexed by token_num is taken.
    """
    masked = batch.clone()
    masked[:, mask_idx] = MASK_TOKEN_ID
    with torch.no_grad():
        hidden = model(masked.to(torch.device(device)), output_hidden_states=True).hidden_states[-1]
    if token_num != 0.5:
        return hidden[:, token_num, :]
    return torch.mean(hidden, dim=1)


# Probe training: the classifier is fitted on the validation split and scored on
# the test split after every epoch (the 'Val loss ...' lines are test-split metrics).
for epoch in range(300):

    classifier.train()
    total_loss = 0.

    for i in range(0, len(val_dataset), batch_size):
        end_idx = min(i+batch_size, len(val_dataset))
        embd = _bert_doc_embedding(val_dataset[i:end_idx, :])
        target = val_mixtures[i:end_idx, :].to(torch.device(device))

        output = classifier(embd)

        if train_mode == 'classification':
            loss = criterion(output, target)
        elif train_mode == 'bayesian':
            loss = loss_mle(target, output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print('Total train loss', total_loss/len(val_dataset))

    # ---- evaluation on the test split; nothing is optimised here ----
    classifier.eval()
    total_loss = 0.
    loss_l2 = 0.
    loss_tv = 0.
    loss_ce = 0.
    pred_results = []

    with torch.no_grad():
        for i in range(0, len(test_dataset), batch_size):
            end_idx = min(i+batch_size, len(test_dataset))
            embd = _bert_doc_embedding(test_dataset[i:end_idx, :])
            target = test_mixtures[i:end_idx, :].to(torch.device(device))

            output = classifier(embd)
            #output = torch.zeros_like(output)+0.2

            if train_mode == 'classification':
                loss = criterion(output, target)
            elif train_mode == 'bayesian':
                loss = loss_mle(target, output)
            total_loss += loss.item()

            # Accuracy compares the dominant predicted topic with the dominant true topic.
            true_class = torch.argmax(target, 1)
            pred_class = torch.argmax(output, 1)
            pred_results.extend(list((true_class == pred_class).cpu().numpy()))

            if train_mode == 'classification':
                # total_loss already is the running summed cross-entropy.
                loss_ce = total_loss
                probs = output.softmax(1)
                loss_l2 += torch.sum(torch.sum((probs-target)**2, dim=1)).item()
                loss_tv += torch.sum(get_loss_tv(probs, target)).item()
            elif train_mode == 'bayesian':
                # Point estimate: the average of 10 samples from the predicted Dirichlet.
                q = D.dirichlet.Dirichlet(output)
                samp = torch.mean(q.sample([10]), dim=0)
                loss_ce += -torch.sum(torch.sum(torch.multiply(torch.log(samp), target), dim=1)).item()
                loss_l2 += torch.sum(torch.sum((samp-target)**2, dim=1)).item()
                loss_tv += torch.sum(get_loss_tv(samp, target)).item()

    pred_results = np.array(pred_results)
    acc = np.mean(pred_results)

    # All sums are divided by the number of test documents.
    print('Val loss CE', loss_ce/len(test_dataset))
    print('Val loss L2', loss_l2/len(test_dataset))
    print('Val loss TV', loss_tv/len(test_dataset))
    print('Accuracy', acc)

    # total variation distance

### LDA

In [None]:
import gensim
from gensim.models import LdaModel
from gensim import corpora

# Contiguous train / validation / test splits at the val_idx and test_idx boundaries.
train_dataset = dataset[:val_idx]
val_dataset = dataset[val_idx:test_idx]
test_dataset = dataset[test_idx:]
train_mixtures = topic_mixtures[:val_idx]
val_mixtures = topic_mixtures[val_idx:test_idx]
test_mixtures = topic_mixtures[test_idx:]

# gensim is fed string tokens, so every word id is converted to its string form.
dataset_lda = np.array(dataset.cpu().numpy(), dtype=str)
dictionary = corpora.Dictionary(dataset_lda)

# Fill the reverse lookup (dictionary id -> token string) explicitly from token2id.
for key in dictionary.token2id:
    dictionary.id2token[dictionary.token2id[key]] = key

# Bag-of-words representation of every document (all splits).
lda_corpus = [dictionary.doc2bow(text) for text in dataset_lda]

In [None]:
# Fit LDA on the training documents only. The fit is seeded with the notebook
# seed so that the order of the learned topics — which the hand-written topic
# matching in the following cells depends on — is the same on every run.
lda = LdaModel(lda_corpus[:val_idx], num_topics=5, iterations=3000, passes=10, random_state=seed)
# Per-word likelihood bound on the validation documents.
print('LDA lower bound', lda.log_perplexity(lda_corpus[val_idx:test_idx]))

In [None]:
# Ground-truth generator: for every true topic, the word ids ordered from most
# to least probable. Compare with the LDA topics printed in the next cell.
torch.argsort(topic_vecs, 1, descending=True)

In [None]:
# Print the 20 highest-weighted words of each of the 5 learned LDA topics,
# translated from gensim dictionary ids back to token strings.
print('-- Printing topics --')
for i in range(5):
    word_list = []
    pairs = lda.get_topic_terms(i, topn=20)
    for pair in pairs:
        # pair[0] is the dictionary id of the word.
        word_list.append(dictionary.id2token[pair[0]])
    print(word_list)
print('---------------------')

# Correspondence between learned and true topics, apparently read off by hand
# from the output above for one particular fit — TODO re-derive after refitting.
# topic 3 of trained LDA is topic 0 of the generator
# topic 4 is topic 1
# topic 0 is topic 2
# topic 1 is topic 3
# topic 2 is topic 4

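Here you need to match the ground-truth topic ids to the LDA topic ids: the trained LDA orders its topics arbitrarily, so compare the two outputs above and update the index mapping in the next cell accordingly.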

In [None]:
# lda_topic_for_true[k] is the index of the trained-LDA topic that is matched to
# ground-truth topic k.
# NOTE(review): the comment in the topic-printing cell lists LDA topic 4 for
# true topic 1 and LDA topic 2 for true topic 4, whereas this mapping uses 2
# and 4 respectively — confirm which one matches the current fit.
lda_topic_for_true = [3, 2, 0, 1, 4]

trained_lda_mixtures = []
for i in range(len(dataset_lda)):
    # get_document_topics yields (topic_id, probability) pairs. Look the
    # probabilities up by topic id instead of by list position, so that a topic
    # missing from the returned list cannot shift the remaining entries; a
    # missing topic counts as probability 0.
    topic_probs = dict(lda.get_document_topics(lda_corpus[i], minimum_probability=1e-15))
    trained_lda_mixtures.append([topic_probs.get(t, 0.0) for t in lda_topic_for_true])

In [None]:
def get_loss_tv(output, target):
    """Per-row largest absolute difference between output and target (no reduction over rows)."""
    largest_gap, _ = torch.max((output - target).abs(), dim=1)
    return largest_gap

def get_loss_l2(output, target):
    """Squared Euclidean distance per row, averaged over all rows."""
    per_row = torch.square(output - target).sum(dim=1)
    return per_row.mean()

# Score the LDA topic mixtures of the test documents against the ground truth.
output = torch.tensor(trained_lda_mixtures[test_idx:]).to(torch.device(device))
target = test_mixtures.to(torch.device(device))

# Cross-entropy of the predicted mixture under the true mixture, squared
# Euclidean distance and largest coordinate gap, each averaged over documents.
loss_ce = -torch.mean(torch.sum(torch.multiply(torch.log(output), target), dim=1))
loss_l2 = torch.mean(torch.sum((output-target)**2, dim=1))
loss_tv = torch.mean(get_loss_tv(output, target))

# lsexp = torch.logsumexp(output, dim=1, keepdims=True)
# print(-torch.mean(torch.sum(torch.multiply(output - lsexp, target), dim=1)))
# criterion = torch.nn.CrossEntropyLoss()
# print(criterion(output, target))

# Accuracy compares the dominant predicted topic with the dominant true topic.
true_class = torch.argmax(target, 1)
pred_class = torch.argmax(output, 1)
pred_result = true_class == pred_class
print('Accuracy', np.mean(pred_result.cpu().numpy()))
print('Loss CE', loss_ce)
print('Loss L2', loss_l2)
print('Loss TV', loss_tv)

### Word Embedder

In [None]:
class WordEmbedder(torch.nn.Module):
    """Bag-of-embeddings baseline: average the scaled word embeddings, then classify.

    Args:
        vocab_size: number of rows of the embedding table.
        d_model: embedding width.
        K: number of outputs of the linear classifier.
    """

    def __init__(self, vocab_size, d_model, K):
        super().__init__()
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        self.classifier = torch.nn.Linear(d_model, K)

    def forward(self, src):
        """Return K scores per document for token ids indexed (document, position)."""
        scaled = self.embedding(src) * np.sqrt(self.d_model)
        # Averaging over the position axis discards word order.
        pooled = scaled.mean(dim=1)
        return self.classifier(pooled)


In [None]:
# Move the corpus and its ground-truth mixtures to the compute device.
dataset, topic_mixtures = (t.to(torch.device(device)) for t in (dataset, topic_mixtures))

# Contiguous train / validation / test splits at the val_idx and test_idx boundaries.
train_dataset, val_dataset, test_dataset = (
    dataset[:val_idx], dataset[val_idx:test_idx], dataset[test_idx:]
)
train_mixtures, val_mixtures, test_mixtures = (
    topic_mixtures[:val_idx], topic_mixtures[val_idx:test_idx], topic_mixtures[test_idx:]
)

def get_loss_tv(output, target):
    """Per-row largest absolute difference between output and target (no reduction over rows)."""
    largest_gap, _ = torch.max((output - target).abs(), dim=1)
    return largest_gap

def get_loss_l2(output, target):
    """Squared Euclidean distance per row, averaged over all rows."""
    per_row = torch.square(output - target).sum(dim=1)
    return per_row.mean()

In [None]:
# Bag-of-embeddings baseline: vocabulary of V words, embedding width 128,
# 5 outputs (one per topic). Note that this rebinds the global name `model`.
model = WordEmbedder(V, 128, 5).to(torch.device(device))

# Default reduction: the loss returned for a batch is averaged over the batch.
criterion = torch.nn.CrossEntropyLoss()
lr = 0.0001  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
batch_size = 64
# Not read by the training loop that follows in this cell.
min_val_loss = 1e6

# Supervised baseline: trained on the training split, scored on the test split
# after every epoch (the 'Val loss ...' lines are test-split metrics).
for epoch in range(300):

    model.train()
    total_loss = 0.

    for i in range(0, len(train_dataset), batch_size):
        end_idx = min(i+batch_size, len(train_dataset))
        data = train_dataset[i:end_idx, :]
        target = train_mixtures[i:end_idx, :]

        output = model(data)

        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch; multiply by the batch size
        # so that the running total is a sum over documents.
        total_loss += loss.item() * data.shape[0]

    # Per-document average, comparable across epochs and batch sizes.
    print('Total train loss', total_loss / len(train_dataset))

    # ---- evaluation on the test split; nothing is optimised here ----
    model.eval()
    loss_l2 = 0.
    loss_tv = 0.
    loss_ce = 0.
    pred_results = []

    with torch.no_grad():
        for i in range(0, len(test_dataset), batch_size):
            end_idx = min(i+batch_size, len(test_dataset))
            data = test_dataset[i:end_idx, :]
            target = test_mixtures[i:end_idx, :]

            output = model(data)

            # Sum over the documents of the batch, so that dividing by the number
            # of test documents below yields a true per-document cross-entropy.
            loss_ce += criterion(output, target).item() * data.shape[0]

            # Accuracy compares the dominant predicted topic with the dominant true topic.
            true_class = torch.argmax(target, 1)
            pred_class = torch.argmax(output, 1)
            pred_results.extend(list((true_class == pred_class).cpu().numpy()))

            probs = output.softmax(1)
            loss_l2 += torch.sum(torch.sum((probs-target)**2, dim=1)).item()
            loss_tv += torch.sum(get_loss_tv(probs, target)).item()

    pred_results = np.array(pred_results)
    acc = np.mean(pred_results)

    print('Val loss CE', loss_ce/len(test_dataset))
    print('Val loss L2', loss_l2/len(test_dataset))
    print('Val loss TV', loss_tv/len(test_dataset))
    print('Accuracy', acc)



### Analysis

In [None]:
import matplotlib.pyplot as plt

idx = 0

def plot_example(idx, method):
    """Plot predicted (stems) against true (bars) topic proportions for one validation document.

    Args:
        idx: row of the document inside the validation split.
        method: 'lda' reads the prediction from `trained_lda_mixtures`;
            'llm' probes the transformer language model with `classifier`.
            Matching is case-insensitive.

    Raises:
        ValueError: if `method` is neither 'lda' nor 'llm'.

    The 'llm' branch relies on the notebook globals `model`, `classifier` and
    `token_num`. `model` must be the TransformerModel the probe was trained on;
    later cells rebind that name to other models, so restore it before calling.
    """
    target = torch.squeeze(val_mixtures[[idx], :]).cpu().numpy()
    kind = method.lower()

    if kind == 'lda':
        # trained_lda_mixtures covers the whole corpus, and the validation split
        # starts at row val_idx of it.
        samp = trained_lda_mixtures[val_idx + idx]
    elif kind == 'llm':
        data = val_dataset[[idx], :-1]
        src_mask = generate_square_subsequent_mask(data.shape[1], device=device)
        with torch.no_grad():
            # The forward pass stores the encoder output on model.doc_embd.
            _ = model(data, src_mask=src_mask)
            hidden = model.doc_embd
            # Same pooling convention as the probing cells: 0.5 means "average
            # over positions", anything else is a position index.
            if token_num == 0.5:
                embd = torch.mean(hidden, dim=1)
            else:
                embd = hidden[:, token_num, :]
            output = classifier(embd).cpu()
        samp = np.squeeze(torch.nn.functional.softmax(output, 1).numpy())
    else:
        raise ValueError(f"method must be 'lda' or 'llm', got {method!r}")

    plt.stem(samp)
    plt.bar([0,1,2,3,4], target, alpha=0.2)
    plt.title(f'{method}: datapoint {idx}')

In [None]:
# 2 x 4 grid: top row LDA, bottom row LLM, one column per example document.
plt.figure(figsize=(16,8))
example_ids = (0, 100, 200, 300)
for row, method in enumerate(('LDA', 'LLM')):
    for col, example_id in enumerate(example_ids):
        # Subplot numbers run 1..8, row by row.
        plt.subplot(2, 4, row * len(example_ids) + col + 1)
        plot_example(example_id, method)
plt.savefig('stickplots.pdf', format="pdf", bbox_inches="tight")
plt.show()