Libraries

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import torch.nn as nn
import torch
import pandas as pd
import time
import math
from transformers import AutoModel, AutoTokenizer


In [4]:
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


Data loader

In [6]:
class DataLoader():
    def __init__(self, filepath, num_classes, train_size, valid_size, test_size, device, batch_size, chunksize=10000):
        self.batch_size = batch_size
        self.device = device
        self.train_size = train_size
        self.valid_size = valid_size
        self.test_size = test_size
        self.mode = 'train'
        self.num_classes = num_classes

        self.dataloader = pd.read_csv(filepath, usecols=[5,8], iterator=True, chunksize=chunksize)

        model_name = 'distilbert-base-cased'
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        raw_data = None

        for index, chunk in enumerate(self.dataloader):
            if index == 0:
                raw_data = chunk
            else:
                raw_data = pd.concat([raw_data, chunk], axis=0)
        
        self.filtered = []
        sizes = []
        for i in range(num_classes):
          self.filtered.append(raw_data[raw_data['stars'] == i + 1])
          sizes.append(self.filtered[i].shape[0])
        size = min(sizes)
        self.data = self.filtered[0][:size]
        for i in range(1, num_classes):
          self.data = pd.concat([self.data, self.filtered[i][:size]], axis=0)

        self.data = self.data.sample(frac=1)
        self.data_size = self.data.shape[0]
        self.offset = valid_size + test_size


    def set_mode(self, mode):
        self.mode = mode

    def __iter__(self):
        self.ptr_idx = 0
        self.end_idx = 0
        if self.mode == 'train':
            if self.offset + self.train_size >= self.data_size:
                self.offset = self.valid_size + self.test_size
            self.ptr_idx = self.offset
            self.end_idx = self.offset + self.train_size
        elif self.mode == 'valid':
            self.ptr_idx = 0
            self.end_idx = self.valid_size
        elif self.mode == 'test':
            self.ptr_idx = self.valid_size
            self.end_idx = self.valid_size + self.test_size
        else: raise Exception("Incorrect mode chosen!!!")
        return self

    def __next__(self): 
        if self.ptr_idx + self.batch_size < self.end_idx:
            batch = self.data[self.ptr_idx : self.ptr_idx + self.batch_size]
            self.ptr_idx += self.batch_size
            return self.get_batch(batch)
        else:
            raise StopIteration

    def get_batch(self, batch):
        targets = [] # b
        emb_holder = [] # b x s x w x e
        mask_holder = [] # b x s x w
        for _, row in batch.iterrows():
            targets.append(int(row['stars']) - 1)
            embedding, mask = self.get_document(row['text'])
            emb_holder.append(embedding)
            mask_holder.append(mask)
        return self.adjust_batch(emb_holder, mask_holder), torch.tensor(targets, dtype=torch.long, device=self.device)
 
    def get_document(self, text):
        with torch.no_grad():                 
            sentences = sent_tokenize(text)
            model_inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(self.device)
            embedding = self.model(model_inputs.input_ids, attention_mask=model_inputs.attention_mask).last_hidden_state.to(self.device)
            return (embedding, model_inputs.attention_mask)
    
    def adjust_batch(self, emb_holder, mask_holder):
        s_size = emb_holder[0].size(dim=0)
        w_size = emb_holder[0].size(dim=1)
        e_size = emb_holder[0].size(dim=2)

        embeddings = None
        masks = None
        sen_masks = None
        for i in range(1, len(emb_holder)):
            if s_size < emb_holder[i].size(dim=0):
                s_size = emb_holder[i].size(dim=0)
            if w_size < emb_holder[i].size(dim=1):
                w_size = emb_holder[i].size(dim=1)

        init = True
        for emb, mask in zip(emb_holder, mask_holder):
            sen_mask = torch.cat((torch.ones(emb.size(dim=0), device=self.device),
                                  torch.zeros(s_size - emb.size(dim=0), device=self.device)), dim=0).unsqueeze(dim=0)

            emb = torch.cat((emb, torch.zeros(emb.size(dim=0), w_size - emb.size(dim=1), e_size, device=self.device)), dim=1)
            emb = torch.cat((emb, torch.zeros(s_size - emb.size(dim=0), w_size, e_size, device=self.device)), dim=0).unsqueeze(dim=0)

            mask = torch.cat((mask, torch.zeros(mask.size(dim=0), w_size - mask.size(dim=1), device=self.device)), dim=1)
            mask = torch.cat((mask, torch.zeros(s_size - mask.size(dim=0), w_size, device=self.device)), dim=0).unsqueeze(dim=0)
            if init:
                embeddings = emb
                masks = mask
                sen_masks = sen_mask
                init = False
            else:
                embeddings = torch.cat((embeddings, emb), dim=0)
                masks = torch.cat((masks, mask), dim=0)
                sen_masks = torch.cat((sen_masks, sen_mask), dim=0)
        return (embeddings, masks, sen_masks)



In [7]:
FILEPATH = "/content/drive/MyDrive/Colab/yelp_academic_dataset_review.csv"
NUM_CLASSES = 5
CHUNK_SIZE = 10000
TRAIN_SIZE= 10000
VALID_SIZE = 2000
TEST_SIZE = 2000
BATCH_SIZE = 4
dataloader = DataLoader(FILEPATH, NUM_CLASSES, TRAIN_SIZE, VALID_SIZE, TEST_SIZE, device, BATCH_SIZE, CHUNK_SIZE)

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model

In [8]:
class HCAN(nn.Module):
    def __init__(self, num_classes, dim, heads, device, word_kernel, sen_kernel, dtype=torch.float32):
        super(HCAN, self).__init__()
        padding_word = word_kernel - 2
        padding_sen = sen_kernel - 2
        self.Qa = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim,  padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )        
        self.Ka = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.Va = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.Qb = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )        
        self.Kb = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.Vb = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.Tanh()
        )

        self.Kt = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.Vt = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_word, kernel_size=word_kernel,
            device=device, dtype=dtype),
            nn.Tanh()
        )

        self.multihead_word  = nn.MultiheadAttention(embed_dim=dim, num_heads=heads,
                                                batch_first=True, device=device, dtype=dtype)
        
        self.T_word = nn.Parameter(torch.randn(1, 1, dim, device=device),  requires_grad=True)

        self.QaS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim,  padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )        
        self.KaS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.VaS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.QbS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )        
        self.KbS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.VbS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.Tanh()
        )

        self.KtS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.SELU()
        )
        self.VtS = nn.Sequential(
            nn.Conv1d(in_channels=dim, out_channels=dim, padding=padding_sen, kernel_size=sen_kernel,
            device=device, dtype=dtype),
            nn.Tanh()
        )

        self.multihead_sen  = nn.MultiheadAttention(embed_dim=dim, num_heads=heads,
                                                batch_first=True, device=device, dtype=dtype)
        
        self.T_sen = nn.Parameter(torch.randn(1, 1, dim, device=device), requires_grad=True)
        
        self.softmax = nn.Softmax(dim=2)

        self.classifier = nn.Sequential(
            nn.Linear(dim, 100, device=device),
            nn.Sigmoid(),
            nn.Linear(100, 20, device=device),
            nn.Sigmoid(),
            nn.Linear(20, num_classes, device=device),
            nn.Sigmoid()
        )

    def forward(self, input): # E b x s x l x d
        sizes = input[0].size()

        E = input[0].view(sizes[0] * sizes[1], sizes[2], sizes[3])
        E = E.transpose(2, 1)

        M = input[1].view(sizes[0] * sizes[1], sizes[2])
        M = M.unsqueeze(dim=-1)

        S = input[2].unsqueeze(dim=-1)

        Qa = torch.mul(M, self.Qa(E).transpose(1, 2))
        Ka = torch.mul(M, self.Ka(E).transpose(1, 2))
        Va = torch.mul(M, self.Va(E).transpose(1, 2))
        Qb = torch.mul(M, self.Qb(E).transpose(1, 2))
        Kb = torch.mul(M, self.Kb(E).transpose(1, 2))
        Vb = torch.mul(M, self.Vb(E).transpose(1, 2))
        A = torch.mul(M, self.multihead_word(Qa, Ka, Va)[0])
        B = torch.mul(M, self.multihead_word(Qb, Kb, Vb)[0])
        Eout = torch.mul(A, B).transpose(2, 1)
        Kt = torch.mul(M, self.Kt(Eout).transpose(1, 2))
        Vt = torch.mul(M, self.Vt(Eout).transpose(1, 2))

        D = self.multihead_word(self.T_word.expand(Kt.size(dim=0), 1, Kt.size(dim=2)), Kt, Vt)[0].squeeze(dim=1)
        
        D = D.view(sizes[0], sizes[1], sizes[3])
        D = torch.mul(S, D)
        D = D.transpose(2, 1)
        QaS = torch.mul(S, self.QaS(D).transpose(1, 2))
        KaS = torch.mul(S, self.KaS(D).transpose(1, 2))
        VaS = torch.mul(S, self.VaS(D).transpose(1, 2))
        QbS = torch.mul(S, self.QbS(D).transpose(1, 2))
        KbS = torch.mul(S, self.KbS(D).transpose(1, 2))
        VbS = torch.mul(S, self.VbS(D).transpose(1, 2))
        AS = torch.mul(S, self.multihead_sen(QaS, KaS, VaS)[0])
        BS = torch.mul(S, self.multihead_sen(QbS, KbS, VbS)[0])
        Dout = torch.mul(AS, BS).transpose(1, 2)
        KtS = torch.mul(S, self.KtS(Dout).transpose(1, 2))
        VtS = torch.mul(S, self.VtS(Dout).transpose(1, 2))

        doc = self.multihead_sen(self.T_sen.expand(KtS.size(dim=0), 1, KtS.size(dim=2)), KtS, VtS)[0] # b x 1 x e

        return self.softmax(self.classifier(self.softmax(doc))).squeeze(dim=1)

Hyperparameters

In [9]:
EMBEDDING_SIZE = 768
NUM_HEADS = 8
WINDOW_SIZE = 3
SEN_WINDOW_SIZE = 3
hcan = HCAN(NUM_CLASSES, EMBEDDING_SIZE, NUM_HEADS, device, WINDOW_SIZE, SEN_WINDOW_SIZE)

In [10]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

hcan.apply(initialize_weights)
hcan.load_state_dict(torch.load('/content/drive/MyDrive/my-model-6.pt'))

<All keys matched successfully>

In [11]:
def percentage(model, data):
  guessed = [0 for i in range(5)]
  model.eval()
  with torch.no_grad():
    inputs = data.get_examples(20)
    for star, input in enumerate(inputs):
      for x in input:
        if star == torch.argmax(model.forward(x)).item():
          guessed[star] += 1
  print(f"Accuracy: {sum(guessed)}%")
  for i in range(5):
    print(f"{i} star: {round(guessed[i]/len(inputs[i])*100)}%")

In [12]:
def train(model, dataloader, optimizer, criterion):
    model.train()
    
    epoch_loss = 0
    dataloader.set_mode('train')
    size = dataloader.train_size/dataloader.batch_size
    count = 0
    blyat = 0
    step = math.floor(size/100)
    prev = 0
    for input, target in dataloader:
        if count >= step:
          msg = f"Train processed: {round((blyat*count)/size*100)}%"
          if blyat != 0:
            print("\b"*prev)
          prev = len(msg) + 2
          blyat += 1
          print(msg)
          count = 0
        count += 1
        optimizer.zero_grad()
        predicted = model.forward(input)
        loss = criterion(predicted, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print("\b"*prev)
        
    return epoch_loss

In [13]:
def evaluate(model, dataloader, criterion):
    model.eval()
    
    epoch_loss = 0
    dataloader.set_mode('valid')
    size = dataloader.valid_size/dataloader.batch_size
    count = 0
    blyat = 0
    step = math.floor(size/100)
    prev = 0
    with torch.no_grad():
        for input, target in dataloader:
            if count >= step:
              msg = f"Validation processed: {round((blyat*count)/size*100)}%"
              if blyat != 0:
                print("\b"*prev)
              prev = len(msg) + 2
              blyat += 1
              print(msg)
              count = 0
            count += 1
            predicted = model.forward(input)
            loss = criterion(predicted, target)
            epoch_loss += loss.item()
    print("\b"*prev)
        
    return epoch_loss

In [14]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
EPOCHS = 100
LEARNING_RATE = 0.01
optimizer = torch.optim.SGD(hcan.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
for epoch in range(EPOCHS):

    start_time = time.time()

    train_loss = train(hcan, dataloader, optimizer, criterion)
    valid_loss = evaluate(hcan, dataloader, criterion)

    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    torch.save(hcan.state_dict(), '/content/drive/MyDrive/my-model-{}.pt'.format(epoch + 1))
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')
    #percentage(hcan, dataloader)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Validation processed: 78%

Validation processed: 79%

Validation processed: 80%

Validation processed: 81%

Validation processed: 82%

Validation processed: 83%

Validation processed: 84%

Validation processed: 85%

Validation processed: 86%

Validation processed: 87%

Validation processed: 88%

Validation processed: 89%

Validation processed: 90%

Validation processed: 91%

Validation processed: 92%

Validation processed: 93%

Validation processed: 94%

Validation proces